diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 70d5789f2cd8..d0d957f803e4 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -9,13 +9,15 @@ local strip = require("kong.tools.utils").strip local EMPTY = {} ---- Return a 400 response with a JSON body. This function is used to --- return errors to the client while also logging the error. -local function bad_request(msg) - kong.log.info(msg) - return kong.response.exit(400, { error = { message = msg } }) -end +local function bail(code, msg) + if code == 400 and msg then + kong.log.info(msg) + end + if ngx.get_phase() ~= "balancer" then + return kong.response.exit(code, msg and { error = { message = msg } } or nil) + end +end -- get the token text from an event frame @@ -150,6 +152,9 @@ local function handle_streaming_frame(conf) end function _M:header_filter(conf) + -- free up the buffered body used in the access phase + kong.ctx.shared.ai_request_body = nil + local kong_ctx_plugin = kong.ctx.plugin local kong_ctx_shared = kong.ctx.shared @@ -299,6 +304,10 @@ function _M:access(conf) local request_table local multipart = false + -- TODO: the access phase may be called mulitple times also in the balancer phase + -- Refactor this function a bit so that we don't mess them in the same function + local balancer_phase = ngx.get_phase() == "balancer" + -- we may have received a replacement / decorated request body from another AI plugin if kong_ctx_shared.replacement_request then kong.log.debug("replacement request body received from another AI plugin") @@ -308,11 +317,19 @@ function _M:access(conf) -- first, calculate the coordinates of the request local content_type = kong.request.get_header("Content-Type") or "application/json" - request_table = kong.request.get_body(content_type) + request_table = kong_ctx_shared.ai_request_body + if not request_table then + if balancer_phase then + error("Too late to read body", 2) + end + + request_table = kong.request.get_body(content_type) + kong_ctx_shared.ai_request_body = request_table + end if not request_table then if not string.find(content_type, "multipart/form-data", nil, true) then - return bad_request("content-type header does not match request body") + return bail(400, "content-type header does not match request body") end multipart = true -- this may be a large file upload, so we have to proxy it directly @@ -322,7 +339,7 @@ function _M:access(conf) -- resolve the real plugin config values local conf_m, err = ai_shared.resolve_plugin_conf(kong.request, conf) if err then - return bad_request(err) + return bail(400, err) end -- copy from the user request if present @@ -337,13 +354,13 @@ function _M:access(conf) -- check that the user isn't trying to override the plugin conf model in the request body if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then if request_table.model ~= conf_m.model.name then - return bad_request("cannot use own model - must be: " .. conf_m.model.name) + return bail(400, "cannot use own model - must be: " .. conf_m.model.name) end end -- model is stashed in the copied plugin conf, for consistency in transformation functions if not conf_m.model.name then - return bad_request("model parameter not found in request, nor in gateway configuration") + return bail(400, "model parameter not found in request, nor in gateway configuration") end kong_ctx_plugin.llm_model_requested = conf_m.model.name @@ -353,7 +370,7 @@ function _M:access(conf) local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then kong_ctx_shared.skip_response_transformer = true - return bad_request(err) + return bail(400, err) end end @@ -361,7 +378,7 @@ function _M:access(conf) local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then kong_ctx_shared.skip_response_transformer = true - return bad_request(err) + return bail(400, err) end -- check if the user has asked for a stream, and/or if @@ -373,7 +390,7 @@ function _M:access(conf) -- this condition will only check if user has tried -- to activate streaming mode within their request if conf_m.response_streaming and conf_m.response_streaming == "deny" then - return bad_request("response streaming is not enabled for this LLM") + return bail(400, "response streaming is not enabled for this LLM") end -- store token cost estimate, on first pass @@ -381,7 +398,7 @@ function _M:access(conf) local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) if err then kong.log.err("unable to estimate request token cost: ", err) - return kong.response.exit(500) + return bail(500) end kong_ctx_plugin.ai_stream_prompt_tokens = prompt_tokens @@ -399,7 +416,7 @@ function _M:access(conf) -- execute pre-request hooks for this driver local ok, err = ai_driver.pre_request(conf_m, request_table) if not ok then - return bad_request(err) + return bail(400, err) end -- transform the body to Kong-format for this provider/model @@ -409,14 +426,14 @@ function _M:access(conf) parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type) if err then kong_ctx_shared.skip_response_transformer = true - return bad_request(err) + return bail(400, err) end end -- execute pre-request hooks for "all" drivers before set new body local ok, err = ai_shared.pre_request(conf_m, parsed_request_body) if not ok then - return bad_request(err) + return bail(400, err) end if route_type ~= "preserve" then @@ -428,7 +445,7 @@ function _M:access(conf) if not ok then kong_ctx_shared.skip_response_transformer = true kong.log.err("failed to configure request for AI service: ", err) - return kong.response.exit(500) + return bail(500) end -- lights out, and away we go