From 360b4f0e7d8a7d2a7e1ca373fd4f7d43a76b37e1 Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Mon, 8 Jul 2024 15:32:51 +0800 Subject: [PATCH] feat(llm/proxy): allow balancer retry --- kong/llm/proxy/handler.lua | 57 +++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index fe7497e7198a..43e096c44b2d 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -30,13 +30,15 @@ local EMPTY = {} local _M = {} ---- Return a 400 response with a JSON body. This function is used to --- return errors to the client while also logging the error. -local function bad_request(msg) - kong.log.info(msg) - return kong.response.exit(400, { error = { message = msg } }) -end +local function bail(code, msg) + if code == 400 and msg then + kong.log.info(msg) + end + if ngx.get_phase() ~= "balancer" then + return kong.response.exit(code, msg and { error = { message = msg } } or nil) + end +end -- static messages @@ -275,6 +277,9 @@ local function handle_streaming_frame(conf) end function _M:header_filter(conf) + -- free up the buffered body used in the access phase + kong.ctx.shared.ai_request_body = nil + local kong_ctx_plugin = kong.ctx.plugin local kong_ctx_shared = kong.ctx.shared @@ -427,6 +432,10 @@ function _M:access(conf) local request_table local multipart = false + -- TODO: the access phase may be called multiple times also in the balancer phase + -- Refactor this function a bit so that we don't mess them in the same function + local balancer_phase = ngx.get_phase() == "balancer" + -- we may have received a replacement / decorated request body from another AI plugin if kong_ctx_shared.replacement_request then kong.log.debug("replacement request body received from another AI plugin") @@ -436,11 +445,19 @@ function _M:access(conf) -- first, calculate the coordinates of the request local content_type = kong.request.get_header("Content-Type") or 
"application/json" - request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size) + request_table = kong_ctx_shared.ai_request_body + if not request_table then + if balancer_phase then + error("Too late to read body", 2) + end + + request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size) + kong_ctx_shared.ai_request_body = request_table + end if not request_table then if not string.find(content_type, "multipart/form-data", nil, true) then - return bad_request("content-type header does not match request body, or bad JSON formatting") + return bail(400, "content-type header does not match request body, or bad JSON formatting") end multipart = true -- this may be a large file upload, so we have to proxy it directly @@ -450,7 +467,7 @@ function _M:access(conf) -- resolve the real plugin config values local conf_m, err = ai_shared.resolve_plugin_conf(kong.request, conf) if err then - return bad_request(err) + return bail(400, err) end -- copy from the user request if present @@ -465,13 +482,13 @@ function _M:access(conf) -- check that the user isn't trying to override the plugin conf model in the request body if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then if request_table.model ~= conf_m.model.name then - return bad_request("cannot use own model - must be: " .. conf_m.model.name) + return bail(400, "cannot use own model - must be: " .. 
conf_m.model.name) end end -- model is stashed in the copied plugin conf, for consistency in transformation functions if not conf_m.model.name then - return bad_request("model parameter not found in request, nor in gateway configuration") + return bail(400, "model parameter not found in request, nor in gateway configuration") end kong_ctx_plugin.llm_model_requested = conf_m.model.name @@ -481,7 +498,7 @@ function _M:access(conf) local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then kong_ctx_shared.skip_response_transformer = true - return bad_request(err) + return bail(400, err) end end @@ -494,7 +511,7 @@ function _M:access(conf) -- this condition will only check if user has tried -- to activate streaming mode within their request if conf_m.response_streaming and conf_m.response_streaming == "deny" then - return bad_request("response streaming is not enabled for this LLM") + return bail(400, "response streaming is not enabled for this LLM") end -- store token cost estimate, on first pass, if the @@ -503,7 +520,7 @@ function _M:access(conf) local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) if err then kong.log.err("unable to estimate request token cost: ", err) - return kong.response.exit(500) + return bail(500) end kong_ctx_plugin.ai_stream_prompt_tokens = prompt_tokens @@ -521,7 +538,7 @@ function _M:access(conf) -- execute pre-request hooks for this driver local ok, err = ai_driver.pre_request(conf_m, request_table) if not ok then - return bad_request(err) + return bail(400, err) end -- transform the body to Kong-format for this provider/model @@ -531,17 +548,17 @@ function _M:access(conf) parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type) if err then kong_ctx_shared.skip_response_transformer = true - return bad_request(err) + return bail(400, err) end end -- execute pre-request hooks for "all" drivers before set new body local ok, err = 
ai_shared.pre_request(conf_m, parsed_request_body) if not ok then - return bad_request(err) + return bail(400, err) end - if route_type ~= "preserve" then + if route_type ~= "preserve" and not balancer_phase then kong.service.request.set_body(parsed_request_body, content_type) end @@ -550,7 +567,7 @@ function _M:access(conf) if identity_interface and identity_interface.error then kong.ctx.shared.skip_response_transformer = true kong.log.err("error authenticating with cloud-provider, ", identity_interface.error) - return kong.response.exit(500, "LLM request failed before proxying") + return bail(500, "LLM request failed before proxying") end -- now re-configure the request for this operation type @@ -559,7 +576,7 @@ function _M:access(conf) if not ok then kong_ctx_shared.skip_response_transformer = true kong.log.err("failed to configure request for AI service: ", err) - return kong.response.exit(500) + return bail(500) end -- lights out, and away we go