diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 92bc72ea6416..8004b9384df1 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -1,21 +1,13 @@
 local _M = {}
 
 -- imports
-<<<<<<< HEAD
-local cjson = require("cjson.safe")
-local http = require("resty.http")
-local fmt = string.format
-local os = os
-local parse_url = require("socket.url").parse
-local aws_stream = require("kong.tools.aws_stream")
-=======
 local cjson = require("cjson.safe")
 local http = require("resty.http")
 local fmt = string.format
 local os = os
 local parse_url = require("socket.url").parse
 local llm_state = require("kong.llm.state")
->>>>>>> d9053432f9 (refactor(plugins): move shared ctx usage of ai plugins to use a proper API)
+local aws_stream = require("kong.tools.aws_stream")
 
 --
 -- static
diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua
index 21013518f87f..82769b625b05 100644
--- a/kong/llm/proxy/handler.lua
+++ b/kong/llm/proxy/handler.lua
@@ -121,8 +121,7 @@ local function get_token_text(event_t)
   return (type(token_text) == "string" and token_text) or ""
 end
 
-
-local function handle_streaming_frame(conf)
+local function handle_streaming_frame(conf, chunk, finished)
   -- make a re-usable framebuffer
   local framebuffer = buffer.new()
   local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
@@ -136,9 +135,6 @@ local function handle_streaming_frame(conf)
     kong_ctx_plugin.ai_stream_log_buffer = buffer.new()
   end
 
-  -- now handle each chunk/frame
-  local chunk = ngx.arg[1]
-  local finished = ngx.arg[2]
 
   if type(chunk) == "string" and chunk ~= "" then
     -- transform each one into flat format, skipping transformer errors
@@ -261,144 +257,116 @@
     end
   end
 
-function _M:header_filter(conf)
-  -- free up the buffered body used in the access phase
-  kong.ctx.shared.ai_request_body = nil
-
-  local kong_ctx_plugin = kong.ctx.plugin
+local function transform_body(conf)
+  local route_type = conf.route_type
+  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
 
-  if llm_state.is_response_transformer_skipped() then
-    return
-  end
+  -- Note: even if we are told not to do the response transform, we still need to
+  -- get the body for analytics
 
-  -- clear shared restricted headers
-  for _, v in ipairs(ai_shared.clear_response_headers.shared) do
-    kong.response.clear_header(v)
-  end
+  -- try parsed response from other plugin first
+  local response_body = llm_state.get_parsed_response()
+  -- read from upstream if it's not been parsed/transformed by other plugins
+  if not response_body then
+    response_body = kong.service.response.get_raw_body()
 
-  -- only act on 200 in first release - pass the unmodifed response all the way through if any failure
-  if kong.response.get_status() ~= 200 then
-    return
+    if response_body and kong.service.response.get_header("Content-Encoding") == "gzip" then
+      response_body = kong_utils.inflate_gzip(response_body)
+    end
   end
 
-  -- we use openai's streaming mode (SSE)
-  if llm_state.is_streaming_mode() then
-    -- we are going to send plaintext event-stream frames for ALL models
-    kong.response.set_header("Content-Type", "text/event-stream")
-    return
-  end
+  local err
 
-  local response_body = kong.service.response.get_raw_body()
   if not response_body then
-    return
-  end
+    err = "no response body found when transforming response"
 
-  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
-  local route_type = conf.route_type
+  elseif route_type ~= "preserve" then
+    response_body, err = ai_driver.from_format(response_body, conf.model, route_type)
 
-  -- if this is a 'streaming' request, we can't know the final
-  -- result of the response body, so we just proceed to body_filter
-  -- to translate each SSE event frame
-  if not llm_state.is_streaming_mode() then
-    local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-    if is_gzip then
-      response_body = kong_utils.inflate_gzip(response_body)
+    if err then
+      kong.log.err("issue when transforming the response body for analytics: ", err)
     end
   end
 
-    if route_type == "preserve" then
-      kong_ctx_plugin.parsed_response = response_body
-    else
-      local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
-      if err then
-        kong_ctx_plugin.ai_parser_error = true
+  if err then
+    ngx.status = 500
+    response_body = cjson.encode({ error = { message = err }})
 
-        ngx.status = 500
-        kong_ctx_plugin.parsed_response = cjson.encode({ error = { message = err } })
+  else
+    ai_shared.post_request(conf, response_body)
+  end
 
-      elseif new_response_string then
-        -- preserve the same response content type; assume the from_format function
-        -- has returned the body in the appropriate response output format
-        kong_ctx_plugin.parsed_response = new_response_string
-      end
-    end
+  if accept_gzip() then
+    response_body = kong_utils.deflate_gzip(response_body)
   end
 
-  ai_driver.post_request(conf)
+  kong.ctx.plugin.buffered_response_body = response_body
 end
 
+function _M:header_filter(conf)
+  -- free up the buffered body used in the access phase
+  llm_state.set_request_body_table(nil)
 
-function _M:body_filter(conf)
-  local kong_ctx_plugin = kong.ctx.plugin
+  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
+  ai_driver.post_request(conf)
 
-  -- if body_filter is called twice, then return
-  if kong_ctx_plugin.body_called and not llm_state.is_streaming_mode() then
+  if llm_state.should_disable_ai_proxy_response_transform() then
     return
   end
 
-  local route_type = conf.route_type
+  -- only act on 200 in first release - pass the unmodified response all the way through if any failure
+  if kong.response.get_status() ~= 200 then
+    return
+  end
 
-  if llm_state.is_response_transformer_skipped() and (route_type ~= "preserve") then
-    local response_body = llm_state.get_parsed_response()
-
-    if not response_body and kong.response.get_status() == 200 then
-      response_body = kong.service.response.get_raw_body()
-      if not response_body then
-        kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.",
-                      " Please check AI request transformer plugin response.")
-      else
-        local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-        if is_gzip then
-          response_body = kong_utils.inflate_gzip(response_body)
-        end
-      end
-    else
-      kong.response.exit(500, "no response body found")
-    end
+  -- if not streaming, prepare the response body buffer
+  -- this must be called before sending any response headers so that
+  -- we can modify status code if needed
+  if not llm_state.is_streaming_mode() then
+    transform_body(conf)
+  end
 
-    local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
-    local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
+  -- clear shared restricted headers
+  for _, v in ipairs(ai_shared.clear_response_headers.shared) do
+    kong.response.clear_header(v)
+  end
 
-    if err then
-      kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err)
+  -- we use openai's streaming mode (SSE)
+  if llm_state.is_streaming_mode() then
+    -- we are going to send plaintext event-stream frames for ALL models
+    kong.response.set_header("Content-Type", "text/event-stream")
+  end
 
-    elseif new_response_string then
-      ai_shared.post_request(conf, new_response_string)
-    end
+  if accept_gzip() then
+    kong.response.set_header("Content-Encoding", "gzip")
+  else
+    kong.response.clear_header("Content-Encoding")
   end
+end
 
-  if not llm_state.is_response_transformer_skipped() then
-    if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then
-      return
-    end
 
-    if route_type ~= "preserve" then
-      if llm_state.is_streaming_mode() then
-        handle_streaming_frame(conf)
-      else
-        -- all errors MUST be checked and returned in header_filter
-        -- we should receive a replacement response body from the same thread
-        local original_request = kong_ctx_plugin.parsed_response
-        local deflated_request = original_request
-
-        if deflated_request then
-          local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-          if is_gzip then
-            deflated_request = kong_utils.deflate_gzip(deflated_request)
-          end
+-- body filter is only used for streaming mode; for non-streaming mode, everything
+-- is already done in header_filter. This is because it would be too late to
+-- send the status code if we are modifying non-streaming body in body_filter
+function _M:body_filter(conf)
+  if kong.service.response.get_status() ~= 200 then
+    return
+  end
 
-          kong.response.set_raw_body(deflated_request)
-        end
+  -- emit the full body if not streaming
+  if not llm_state.is_streaming_mode() then
+    ngx.arg[1] = kong.ctx.plugin.buffered_response_body
+    ngx.arg[2] = true
 
-        -- call with replacement body, or original body if nothing changed
-        local _, err = ai_shared.post_request(conf, original_request)
-        if err then
-          kong.log.warn("analytics phase failed for request, ", err)
-        end
-      end
-    end
+    kong.ctx.plugin.buffered_response_body = nil
+    return
   end
 
-  kong_ctx_plugin.body_called = true
+  if not llm_state.should_disable_ai_proxy_response_transform() and
+        conf.route_type ~= "preserve" then
+
+    handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2])
+  end
 end
@@ -474,12 +442,10 @@ function _M:access(conf)
   kong_ctx_plugin.llm_model_requested = conf_m.model.name
 
   -- check the incoming format is the same as the configured LLM format
-  if not multipart then
-    local compatible, err = llm.is_compatible(request_table, route_type)
-    if not compatible then
-      llm_state.set_response_transformer_skipped()
-      return bail(400, err)
-    end
+  local compatible, err = llm.is_compatible(request_table, route_type)
+  if not multipart and not compatible then
+    llm_state.disable_ai_proxy_response_transform()
+    return bail(400, err)
   end
 
   -- check if the user has asked for a stream, and/or if
@@ -527,7 +493,7 @@
     -- transform the body to Kong-format for this provider/model
     parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type)
     if err then
-      llm_state.set_response_transformer_skipped()
+      llm_state.disable_ai_proxy_response_transform()
       return bail(400, err)
     end
   end
@@ -554,7 +520,7 @@ function _M:access(conf)
 
   local ok, err = ai_driver.configure_request(conf_m, identity_interface and identity_interface.interface)
   if not ok then
-    llm_state.set_response_transformer_skipped()
+    llm_state.disable_ai_proxy_response_transform()
     kong.log.err("failed to configure request for AI service: ", err)
     return bail(500)
   end
diff --git a/kong/llm/state.lua b/kong/llm/state.lua
index e45a25f16ef3..fa2c29edf8c1 100644
--- a/kong/llm/state.lua
+++ b/kong/llm/state.lua
@@ -1,5 +1,7 @@
 local _M = {}
 
+-- Set disable_ai_proxy_response_transform if the response is just an error message, or has been
+-- generated by another plugin and should skip further transformation
 function _M.disable_ai_proxy_response_transform()
   kong.ctx.shared.llm_disable_ai_proxy_response_transform = true
 end
diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua
index c7316cf35edc..1bad3a92db3d 100644
--- a/kong/plugins/ai-request-transformer/handler.lua
+++ b/kong/plugins/ai-request-transformer/handler.lua
@@ -41,7 +41,7 @@ end
 
 function _M:access(conf)
   kong.service.request.enable_buffering()
-  llm_state.set_response_transformer_skipped()
+  llm_state.disable_ai_proxy_response_transform()
 
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua
index a350ee74e6fd..815b64f351fa 100644
--- a/kong/plugins/ai-response-transformer/handler.lua
+++ b/kong/plugins/ai-response-transformer/handler.lua
@@ -100,7 +100,7 @@ end
 
 function _M:access(conf)
   kong.service.request.enable_buffering()
-  llm_state.set_response_transformer_skipped()
+  llm_state.disable_ai_proxy_response_transform()
 
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
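
Aside: the llm_state flag pair used throughout this patch is stored in kong.ctx.shared. disable_ai_proxy_response_transform() is the setter and should_disable_ai_proxy_response_transform() is a read-only getter, which is why the access-phase handlers above must call the setter. Below is a minimal Lua sketch of the pair; the setter is taken from the kong/llm/state.lua hunk, while the getter's body is inferred from its usage in header_filter/body_filter and is not part of this patch.

-- Sketch of the cross-plugin flag pair in kong/llm/state.lua.
-- The `kong` global is the PDK, ambient inside Kong plugin code.
local _M = {}

-- setter: called in the access phase to mark the response as already
-- handled, so ai-proxy skips its own response transformation later
function _M.disable_ai_proxy_response_transform()
  kong.ctx.shared.llm_disable_ai_proxy_response_transform = true
end

-- getter (body inferred): consulted in header_filter/body_filter; calling
-- it without using the return value is a no-op, since it only reads the flag
function _M.should_disable_ai_proxy_response_transform()
  return kong.ctx.shared.llm_disable_ai_proxy_response_transform
end

return _M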