Skip to content

Commit

Permalink
refactor(plugins/ai-proxy): improve readability using early returns (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
chronolaw authored Apr 1, 2024
1 parent 4c37ce7 commit 6d3e3ab
Showing 1 changed file with 88 additions and 64 deletions.
152 changes: 88 additions & 64 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,91 +5,112 @@ local ai_shared = require("kong.llm.drivers.shared")
local llm = require("kong.llm")
local cjson = require("cjson.safe")
local kong_utils = require("kong.tools.gzip")
local kong_meta = require "kong.meta"
local kong_meta = require("kong.meta")
--


_M.PRIORITY = 770
_M.VERSION = kong_meta.version


-- reuse this table for error message response
local ERROR_MSG = { error = { message = "" } }


local function bad_request(msg)
kong.log.warn(msg)
return kong.response.exit(400, { error = { message = msg } })
ERROR_MSG.error.message = msg

return kong.response.exit(400, ERROR_MSG)
end


local function internal_server_error(msg)
kong.log.err(msg)
return kong.response.exit(500, { error = { message = msg } })
ERROR_MSG.error.message = msg

return kong.response.exit(500, ERROR_MSG)
end


function _M:header_filter(conf)
if not kong.ctx.shared.skip_response_transformer then
-- clear shared restricted headers
for i, v in ipairs(ai_shared.clear_response_headers.shared) do
kong.response.clear_header(v)
end
if kong.ctx.shared.skip_response_transformer then
return
end

-- only act on 200 in first release - pass the unmodifed response all the way through if any failure
if kong.response.get_status() == 200 then
local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
local route_type = conf.route_type
local response_body = kong.service.response.get_raw_body()

if response_body then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"

if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
end

local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
if err then
kong.ctx.plugin.ai_parser_error = true

ngx.status = 500
local message = {
error = {
message = err,
},
}

kong.ctx.plugin.parsed_response = cjson.encode(message)
elseif new_response_string then
-- preserve the same response content type; assume the from_format function
-- has returned the body in the appropriate response output format
kong.ctx.plugin.parsed_response = new_response_string
end

ai_driver.post_request(conf)
end
end
-- clear shared restricted headers
for _, v in ipairs(ai_shared.clear_response_headers.shared) do
kong.response.clear_header(v)
end

-- only act on 200 in first release - pass the unmodifed response all the way through if any failure
if kong.response.get_status() ~= 200 then
return
end

local response_body = kong.service.response.get_raw_body()
if not response_body then
return
end

local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
local route_type = conf.route_type

local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
end

local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
if err then
kong.ctx.plugin.ai_parser_error = true

ngx.status = 500
ERROR_MSG.error.message = err

kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG)

elseif new_response_string then
-- preserve the same response content type; assume the from_format function
-- has returned the body in the appropriate response output format
kong.ctx.plugin.parsed_response = new_response_string
end

ai_driver.post_request(conf)
end


function _M:body_filter(conf)
if not kong.ctx.shared.skip_response_transformer then
if (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error) then
-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = kong.ctx.plugin.parsed_response
if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
end

kong.response.set_raw_body(deflated_request)
end

-- call with replacement body, or original body if nothing changed
ai_shared.post_request(conf, original_request)
if kong.ctx.shared.skip_response_transformer then
return
end

if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
return
end

-- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)

-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = original_request

if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
end

kong.response.set_raw_body(deflated_request)
end

-- call with replacement body, or original body if nothing changed
ai_shared.post_request(conf, original_request)
end


function _M:access(conf)
kong.service.request.enable_buffering()

Expand All @@ -100,10 +121,12 @@ function _M:access(conf)
local ai_driver = require("kong.llm.drivers." .. conf.model.provider)

local request_table

-- we may have received a replacement / decorated request body from another AI plugin
if kong.ctx.shared.replacement_request then
kong.log.debug("replacement request body received from another AI plugin")
request_table = kong.ctx.shared.replacement_request

else
-- first, calculate the coordinates of the request
local content_type = kong.request.get_header("Content-Type") or "application/json"
Expand All @@ -116,7 +139,7 @@ function _M:access(conf)
end

-- check the incoming format is the same as the configured LLM format
local compatible, err = llm.is_compatible(request_table, conf.route_type)
local compatible, err = llm.is_compatible(request_table, route_type)
if not compatible then
kong.ctx.shared.skip_response_transformer = true
return bad_request(err)
Expand Down Expand Up @@ -147,8 +170,9 @@ function _M:access(conf)
if not ok then
return internal_server_error(err)
end

-- lights out, and away we go
end


return _M

1 comment on commit 6d3e3ab

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bazel Build

Docker image available kong/kong:6d3e3abb11951e7b21b9a3f30804743c090bc10d
Artifacts available https://github.com/Kong/kong/actions/runs/8503493360

Please sign in to comment.