refactor(ai-proxy): cleanup body_filter and header_filter
fffonion committed Aug 7, 2024
1 parent 5a41d3a commit 94d3ddc
Showing 5 changed files with 89 additions and 129 deletions.
10 changes: 1 addition & 9 deletions kong/llm/drivers/shared.lua
@@ -1,21 +1,13 @@
 local _M = {}
 
 -- imports
-<<<<<<< HEAD
 local cjson = require("cjson.safe")
 local http = require("resty.http")
 local fmt = string.format
 local os = os
 local parse_url = require("socket.url").parse
-local aws_stream = require("kong.tools.aws_stream")
-=======
-local cjson = require("cjson.safe")
-local http = require("resty.http")
-local fmt = string.format
-local os = os
-local parse_url = require("socket.url").parse
 local llm_state = require("kong.llm.state")
->>>>>>> d9053432f9 (refactor(plugins): move shared ctx usage of ai plugins to use a proper API)
+local aws_stream = require("kong.tools.aws_stream")
 --
 
 -- static
202 changes: 84 additions & 118 deletions kong/llm/proxy/handler.lua
@@ -121,8 +121,7 @@ local function get_token_text(event_t)
   return (type(token_text) == "string" and token_text) or ""
 end
 
-
-local function handle_streaming_frame(conf)
+local function handle_streaming_frame(conf, chunk, finished)
   -- make a re-usable framebuffer
   local framebuffer = buffer.new()
   local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
@@ -136,9 +135,6 @@ local function handle_streaming_frame(conf)
     kong_ctx_plugin.ai_stream_log_buffer = buffer.new()
   end
 
-  -- now handle each chunk/frame
-  local chunk = ngx.arg[1]
-  local finished = ngx.arg[2]
 
   if type(chunk) == "string" and chunk ~= "" then
     -- transform each one into flat format, skipping transformer errors
@@ -261,144 +257,116 @@
   end
 end
 
-function _M:header_filter(conf)
-  -- free up the buffered body used in the access phase
-  kong.ctx.shared.ai_request_body = nil
-
-  local kong_ctx_plugin = kong.ctx.plugin
+local function transform_body(conf)
+  local route_type = conf.route_type
+  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
 
-  if llm_state.is_response_transformer_skipped() then
-    return
-  end
+  -- Note: below even if we are told not to do response transform, we still need to do
+  -- get the body for analytics
 
-  -- clear shared restricted headers
-  for _, v in ipairs(ai_shared.clear_response_headers.shared) do
-    kong.response.clear_header(v)
-  end
+  -- try parsed response from other plugin first
+  local response_body = llm_state.get_parsed_response()
+  -- read from upstream if it's not been parsed/transformed by other plugins
+  if not response_body then
+    response_body = kong.service.response.get_raw_body()
 
-  -- only act on 200 in first release - pass the unmodifed response all the way through if any failure
-  if kong.response.get_status() ~= 200 then
-    return
+    if response_body and kong.service.response.get_header("Content-Encoding") == "gzip" then
+      response_body = kong_utils.inflate_gzip(response_body)
+    end
   end
 
-  -- we use openai's streaming mode (SSE)
-  if llm_state.is_streaming_mode() then
-    -- we are going to send plaintext event-stream frames for ALL models
-    kong.response.set_header("Content-Type", "text/event-stream")
-    return
-  end
+  local err
 
-  local response_body = kong.service.response.get_raw_body()
   if not response_body then
-    return
-  end
+    err = "no response body found when transforming response"
 
-  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
-  local route_type = conf.route_type
+  elseif route_type ~= "preserve" then
+    response_body, err = ai_driver.from_format(response_body, conf.model, route_type)
 
-  -- if this is a 'streaming' request, we can't know the final
-  -- result of the response body, so we just proceed to body_filter
-  -- to translate each SSE event frame
-  if not llm_state.is_streaming_mode() then
-    local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-    if is_gzip then
-      response_body = kong_utils.inflate_gzip(response_body)
+    if err then
+      kong.log.err("issue when transforming the response body for analytics: ", err)
     end
+  end
 
-    if route_type == "preserve" then
-      kong_ctx_plugin.parsed_response = response_body
-    else
-      local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
-      if err then
-        kong_ctx_plugin.ai_parser_error = true
+  if err then
+    ngx.status = 500
+    response_body = cjson.encode({ error = { message = err }})
 
-        ngx.status = 500
-        kong_ctx_plugin.parsed_response = cjson.encode({ error = { message = err } })
+  else
+    ai_shared.post_request(conf, response_body)
+  end
 
-      elseif new_response_string then
-        -- preserve the same response content type; assume the from_format function
-        -- has returned the body in the appropriate response output format
-        kong_ctx_plugin.parsed_response = new_response_string
-      end
-    end
+  if accept_gzip() then
+    response_body = kong_utils.deflate_gzip(response_body)
   end
 
-  ai_driver.post_request(conf)
+  kong.ctx.plugin.buffered_response_body = response_body
 end
 
+function _M:header_filter(conf)
+  -- free up the buffered body used in the access phase
+  llm_state.set_request_body_table(nil)
+
-function _M:body_filter(conf)
-  local kong_ctx_plugin = kong.ctx.plugin
+  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
+  ai_driver.post_request(conf)
 
-  -- if body_filter is called twice, then return
-  if kong_ctx_plugin.body_called and not llm_state.is_streaming_mode() then
+  if llm_state.should_disable_ai_proxy_response_transform() then
     return
   end
 
-  local route_type = conf.route_type
+  -- only act on 200 in first release - pass the unmodifed response all the way through if any failure
+  if kong.response.get_status() ~= 200 then
+    return
+  end
 
-  if llm_state.is_response_transformer_skipped() and (route_type ~= "preserve") then
-    local response_body = llm_state.get_parsed_response()
-
-    if not response_body and kong.response.get_status() == 200 then
-      response_body = kong.service.response.get_raw_body()
-      if not response_body then
-        kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.",
-          " Please check AI request transformer plugin response.")
-      else
-        local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-        if is_gzip then
-          response_body = kong_utils.inflate_gzip(response_body)
-        end
-      end
-    else
-      kong.response.exit(500, "no response body found")
-    end
+  -- if not streaming, prepare the response body buffer
+  -- this must be called before sending any response headers so that
+  -- we can modify status code if needed
+  if not llm_state.is_streaming_mode() then
+    transform_body(conf)
+  end
 
-    local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
-    local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
+  -- clear shared restricted headers
+  for _, v in ipairs(ai_shared.clear_response_headers.shared) do
+    kong.response.clear_header(v)
+  end
 
-    if err then
-      kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err)
+  -- we use openai's streaming mode (SSE)
+  if llm_state.is_streaming_mode() then
+    -- we are going to send plaintext event-stream frames for ALL models
+    kong.response.set_header("Content-Type", "text/event-stream")
+  end
 
-    elseif new_response_string then
-      ai_shared.post_request(conf, new_response_string)
-    end
+  if accept_gzip() then
+    kong.response.set_header("Content-Encoding", "gzip")
+  else
+    kong.response.clear_header("Content-Encoding")
   end
+end
 
-  if not llm_state.is_response_transformer_skipped() then
-    if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then
-      return
-    end
-
-    if route_type ~= "preserve" then
-      if llm_state.is_streaming_mode() then
-        handle_streaming_frame(conf)
-      else
-        -- all errors MUST be checked and returned in header_filter
-        -- we should receive a replacement response body from the same thread
-        local original_request = kong_ctx_plugin.parsed_response
-        local deflated_request = original_request
-
-        if deflated_request then
-          local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-          if is_gzip then
-            deflated_request = kong_utils.deflate_gzip(deflated_request)
-          end
+-- body filter is only used for streaming mode; for non-streaming mode, everything
+-- is already done in header_filter. This is because it would be too late to
+-- send the status code if we are modifying non-streaming body in body_filter
+function _M:body_filter(conf)
+  if kong.service.response.get_status() ~= 200 then
+    return
+  end
 
-          kong.response.set_raw_body(deflated_request)
-        end
+  -- emit the full body if not streaming
+  if not llm_state.is_streaming_mode() then
+    ngx.arg[1] = kong.ctx.plugin.buffered_response_body
+    ngx.arg[2] = true
 
-        -- call with replacement body, or original body if nothing changed
-        local _, err = ai_shared.post_request(conf, original_request)
-        if err then
-          kong.log.warn("analytics phase failed for request, ", err)
-        end
-      end
-    end
-  end
+    kong.ctx.plugin.buffered_response_body = nil
+    return
+  end
 
-  kong_ctx_plugin.body_called = true
+  if not llm_state.should_disable_ai_proxy_response_transform() and
+      conf.route_type ~= "preserve" then
+
+    handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2])
+  end
 end


@@ -474,12 +442,10 @@ function _M:access(conf)
   kong_ctx_plugin.llm_model_requested = conf_m.model.name
 
   -- check the incoming format is the same as the configured LLM format
-  if not multipart then
-    local compatible, err = llm.is_compatible(request_table, route_type)
-    if not compatible then
-      llm_state.set_response_transformer_skipped()
-      return bail(400, err)
-    end
+  local compatible, err = llm.is_compatible(request_table, route_type)
+  if not multipart and not compatible then
+    llm_state.disable_ai_proxy_response_transform()
+    return bail(400, err)
   end
 
   -- check if the user has asked for a stream, and/or if
@@ -527,7 +493,7 @@ function _M:access(conf)
     -- transform the body to Kong-format for this provider/model
     parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type)
     if err then
-      llm_state.set_response_transformer_skipped()
+      llm_state.disable_ai_proxy_response_transform()
       return bail(400, err)
     end
   end
@@ -554,7 +520,7 @@ function _M:access(conf)
   local ok, err = ai_driver.configure_request(conf_m,
     identity_interface and identity_interface.interface)
   if not ok then
-    llm_state.set_response_transformer_skipped()
+    llm_state.disable_ai_proxy_response_transform()
     kong.log.err("failed to configure request for AI service: ", err)
     return bail(500)
   end
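
The net effect of the handler.lua changes: header_filter now performs the whole non-streaming transformation (via transform_body) and buffers the result, while body_filter merely emits that buffer or, in streaming mode, translates each SSE frame. A condensed sketch of the new flow — illustrative only, with `handler` standing in for `_M` and using the functions shown in the diff above:

    -- header_filter runs before any response headers are sent, so a failed
    -- transform can still rewrite the status code to 500
    function handler:header_filter(conf)
      if not llm_state.is_streaming_mode() then
        transform_body(conf)  -- buffers into kong.ctx.plugin.buffered_response_body
      end
    end

    -- body_filter only emits the buffered body, or translates SSE frames
    function handler:body_filter(conf)
      if not llm_state.is_streaming_mode() then
        ngx.arg[1] = kong.ctx.plugin.buffered_response_body  -- replace the body
        ngx.arg[2] = true                                    -- mark the body complete
      else
        handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2])
      end
    end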
2 changes: 2 additions & 0 deletions kong/llm/state.lua
@@ -1,5 +1,7 @@
 local _M = {}
 
+-- Set disable_ai_proxy_response_transform if response is just a error message or has been generated
+-- by plugin and should skip further transformation
 function _M.disable_ai_proxy_response_transform()
   kong.ctx.shared.llm_disable_ai_proxy_response_transform = true
 end
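
The proxy handler above calls llm_state.should_disable_ai_proxy_response_transform(); that getter is outside this diff, but given the flag set here it is presumably just the mirror read:

    -- assumed counterpart to the setter above (not shown in this commit)
    function _M.should_disable_ai_proxy_response_transform()
      return kong.ctx.shared.llm_disable_ai_proxy_response_transform == true
    end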
2 changes: 1 addition & 1 deletion kong/plugins/ai-request-transformer/handler.lua
Expand Up @@ -41,7 +41,7 @@ end

function _M:access(conf)
kong.service.request.enable_buffering()
llm_state.set_response_transformer_skipped()
llm_state.should_disable_ai_proxy_response_transform()

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
Expand Down
2 changes: 1 addition & 1 deletion kong/plugins/ai-response-transformer/handler.lua
@@ -100,7 +100,7 @@ end
 
 function _M:access(conf)
   kong.service.request.enable_buffering()
-  llm_state.set_response_transformer_skipped()
+  llm_state.disable_ai_proxy_response_transform()
 
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
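
Both transformer plugins flip the same flag so that ai-proxy leaves their already-transformed output alone. A sketch of the hand-off, assuming kong.llm.state also exposes a set_parsed_response() setter matching the get_parsed_response() call used by transform_body():

    -- in a transformer plugin's access phase (hypothetical usage)
    kong.service.request.enable_buffering()
    llm_state.set_parsed_response(transformed_body)  -- assumed setter
    llm_state.disable_ai_proxy_response_transform()  -- tell ai-proxy not to re-transform

    -- later, ai-proxy's transform_body() prefers the pre-parsed response:
    local response_body = llm_state.get_parsed_response()
    if not response_body then
      response_body = kong.service.response.get_raw_body()
    end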
