Skip to content

Commit

Permalink
fix(llm): fix gemini transformer plugins, incorrect URL and response …
Browse files Browse the repository at this point in the history
…error parsers (#13703) (#10447)

* fix(gemini-ai): incorrect URL parser for transformer plugins

* fix(ai-transformers): incorrect return parameter used for parser error handling

AG-113

(cherry picked from commit b49cf31)

Co-authored-by: Jack Tysoe <[email protected]>
  • Loading branch information
team-gateway-bot and tysoekong authored Oct 16, 2024
1 parent 5467cee commit 91c762f
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-proxy**: Fixed an issue where AI Transformer plugins always returned a 404 error when using 'Google One' Gemini subscriptions."
type: bugfix
scope: Plugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-transformers**: Fixed a bug where the correct LLM error message was not propagated to the caller."
type: bugfix
scope: Plugin
48 changes: 39 additions & 9 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,19 @@ local function from_gemini_chat_openai(response, model_info, route_type)
}
end

else -- probably a server fault or other unexpected response
elseif response.candidates
and #response.candidates > 0
and response.candidates[1].finishReason
and response.candidates[1].finishReason == "SAFETY" then
local err = "transformation generation candidate breached Gemini content safety"
ngx.log(ngx.ERR, err)
return nil, err

else-- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
return nil, err

end

return cjson.encode(messages)
Expand Down Expand Up @@ -284,13 +293,34 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
return nil, nil, "body must be table or string"
end

-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
)
local operation = llm_state.is_streaming_mode() and "streamGenerateContent"
or "generateContent"
local f_url = conf.model.options and conf.model.options.upstream_url

if not f_url then -- upstream_url override is not set
-- check if this is "public" or "vertex" gemini deployment
if conf.model.options
and conf.model.options.gemini
and conf.model.options.gemini.api_endpoint
and conf.model.options.gemini.project_id
and conf.model.options.gemini.location_id
then
-- vertex mode
f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"],
conf.model.options.gemini.api_endpoint) ..
fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path,
conf.model.options.gemini.project_id,
conf.model.options.gemini.location_id,
conf.model.name,
operation)
else
-- public mode
f_url = ai_shared.upstream_url_format["gemini"] ..
fmt(ai_shared.operation_map["gemini"][conf.route_type].path,
conf.model.name,
operation)
end
end

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

Expand Down Expand Up @@ -319,7 +349,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
local res, err, httpc = ai_shared.http_request(f_url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end
Expand Down
6 changes: 3 additions & 3 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,18 @@ do
if err then
return nil, err
end

-- run the shared logging/analytics/auth function
ai_shared.pre_request(self.conf, ai_request)

-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
if err then
return nil, "failed to introspect request with AI service: " .. err
end

-- parse and convert the response
local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
local ai_response, err, _ = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
if err then
return nil, "failed to convert AI response to Kong format: " .. err
end
Expand Down
60 changes: 60 additions & 0 deletions spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_REQUEST = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -183,6 +206,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}
location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}
location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -229,6 +261,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -333,6 +377,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down
61 changes: 61 additions & 0 deletions spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_INSTRUCTIONS = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -256,6 +279,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}
location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}
location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -344,6 +376,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
parse_llm_response_json_instructions = false,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -491,6 +536,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"candidates": [
{
"finishReason": "SAFETY",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "MEDIUM"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 319,
"totalTokenCount": 319
}
}

0 comments on commit 91c762f

Please sign in to comment.