Skip to content

Commit

Permalink
fix(ai-transformers): incorrect return parameter used for parser erro…
Browse files Browse the repository at this point in the history
…r handling
  • Loading branch information
tysoekong committed Sep 30, 2024
1 parent 456cbfd commit 3ee4797
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**ai-transformers**: Fixed a bug where the correct LLM error message was not propagated to the caller."
type: bugfix
scope: Plugin
11 changes: 10 additions & 1 deletion kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,19 @@ local function from_gemini_chat_openai(response, model_info, route_type)
}
end

else -- probably a server fault or other unexpected response
elseif response.candidates
and #response.candidates > 0
and response.candidates[1].finishReason
and response.candidates[1].finishReason == "SAFETY" then
local err = "transformation generation candidate breached Gemini content safety"
ngx.log(ngx.ERR, err)
return nil, err

else-- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
return nil, err

end

return cjson.encode(messages)
Expand Down
6 changes: 3 additions & 3 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,18 @@ do
if err then
return nil, err
end

-- run the shared logging/analytics/auth function
ai_shared.pre_request(self.conf, ai_request)

-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
if err then
return nil, "failed to introspect request with AI service: " .. err
end

-- parse and convert the response
local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
local ai_response, err, _ = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
if err then
return nil, "failed to convert AI response to Kong format: " .. err
end
Expand Down
60 changes: 60 additions & 0 deletions spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_REQUEST = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -177,6 +200,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}
location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}
location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -223,6 +255,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -327,6 +371,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down
61 changes: 61 additions & 0 deletions spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ local OPENAI_FLAT_RESPONSE = {
},
}

local GEMINI_GOOD = {
route_type = "llm/v1/chat",
logging = {
log_payloads = false,
log_statistics = true,
},
model = {
name = "gemini-1.5-flash",
provider = "gemini",
options = {
max_tokens = 512,
temperature = 0.5,
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/failssafety",
input_cost = 10.0,
output_cost = 10.0,
},
},
auth = {
header_name = "x-goog-api-key",
header_value = "123",
},
}

local OPENAI_BAD_INSTRUCTIONS = {
route_type = "llm/v1/chat",
model = {
Expand Down Expand Up @@ -250,6 +273,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}
location = "/failssafety" {
content_by_lua_block {
local pl_file = require "pl.file"
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails_safety.json"))
}
}
location = "/internalservererror" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -338,6 +370,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
}

local fails_safety = assert(bp.routes:insert {
paths = { "/echo-fails-safety" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = fails_safety.id },
config = {
prompt = SYSTEM_PROMPT,
parse_llm_response_json_instructions = false,
llm = GEMINI_GOOD,
},
}

local internal_server_error = assert(bp.routes:insert {
paths = { "/echo-internal-server-error" }
})
Expand Down Expand Up @@ -485,6 +530,22 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.same({ error = { message = "failed to introspect request with AI service: status code 400" }}, body_table)
end)

it("fails Gemini content-safety", function()
local r = client:get("/echo-fails-safety", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = REQUEST_BODY,
})

local body = assert.res_status(400 , r)
local body_table, err = cjson.decode(body)

assert.is_nil(err)
assert.match_re(body_table.error.message, ".*transformation generation candidate breached Gemini content safety.*")
end)

it("internal server error from LLM", function()
local r = client:get("/echo-internal-server-error", {
headers = {
Expand Down

0 comments on commit 3ee4797

Please sign in to comment.