Commit

feat(llm/proxy): allow balancer retry
fffonion committed Jul 25, 2024
1 parent 30f8a04 commit 4929d46
Showing 1 changed file with 36 additions and 19 deletions.
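Context for the diff below: when the balancer retries an upstream AI target, part of the access-phase request preparation is executed again from the balancer phase, where kong.response.exit cannot produce a response and the client body can no longer be read. The commit therefore replaces the 400-only bad_request helper with a phase-aware bail, caches the parsed request body in kong.ctx.shared.ai_request_body, and skips kong.service.request.set_body on retries. A minimal sketch of the phase-aware exit pattern follows; it mirrors the bail helper introduced in the diff, while the call-site scaffolding is illustrative only and assumes the kong PDK and ngx globals available inside a plugin handler.

```lua
-- Sketch only: error exits that are safe to call from any phase.
-- In the balancer phase there is no downstream response to write,
-- so the helper just logs and returns; elsewhere it ends the request.
local function bail(code, msg)
  if code == 400 and msg then
    kong.log.info(msg)
  end

  if ngx.get_phase() ~= "balancer" then
    return kong.response.exit(code, msg and { error = { message = msg } } or nil)
  end
end

-- Illustrative call site: the same line behaves differently per phase.
-- In access it sends the 400 and stops the request; during a
-- balancer-phase retry it only logs and lets the retry fail upstream.
local request_table = nil  -- pretend body parsing produced nothing
if not request_table then
  return bail(400, "content-type header does not match request body, or bad JSON formatting")
end
```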
55 changes: 36 additions & 19 deletions kong/llm/proxy/handler.lua
@@ -25,13 +25,15 @@ local EMPTY = {}

local _M = {}

--- Return a 400 response with a JSON body. This function is used to
-- return errors to the client while also logging the error.
local function bad_request(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
end
local function bail(code, msg)
if code == 400 and msg then
kong.log.info(msg)
end

if ngx.get_phase() ~= "balancer" then
return kong.response.exit(code, msg and { error = { message = msg } } or nil)
end
end


-- static messages
@@ -228,6 +230,9 @@ local function handle_streaming_frame(conf)
end

function _M:header_filter(conf)
-- free up the buffered body used in the access phase
kong.ctx.shared.ai_request_body = nil

local kong_ctx_plugin = kong.ctx.plugin
local kong_ctx_shared = kong.ctx.shared

@@ -377,6 +382,10 @@ function _M:access(conf)
local request_table
local multipart = false

-- TODO: the access phase may also be called multiple times from the balancer phase
-- Refactor this function a bit so that we don't mix the two phases in the same function
local balancer_phase = ngx.get_phase() == "balancer"

-- we may have received a replacement / decorated request body from another AI plugin
if kong_ctx_shared.replacement_request then
kong.log.debug("replacement request body received from another AI plugin")
@@ -386,11 +395,19 @@
-- first, calculate the coordinates of the request
local content_type = kong.request.get_header("Content-Type") or "application/json"

request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size)
request_table = kong_ctx_shared.ai_request_body
if not request_table then
if balancer_phase then
error("Too late to read body", 2)
end

request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size)
kong_ctx_shared.ai_request_body = request_table
end

if not request_table then
if not string.find(content_type, "multipart/form-data", nil, true) then
return bad_request("content-type header does not match request body, or bad JSON formatting")
return bail(400, "content-type header does not match request body, or bad JSON formatting")
end

multipart = true -- this may be a large file upload, so we have to proxy it directly
@@ -400,7 +417,7 @@
-- resolve the real plugin config values
local conf_m, err = ai_shared.resolve_plugin_conf(kong.request, conf)
if err then
return bad_request(err)
return bail(400, err)
end

-- copy from the user request if present
@@ -415,13 +432,13 @@
-- check that the user isn't trying to override the plugin conf model in the request body
if request_table and request_table.model and type(request_table.model) == "string" and request_table.model ~= "" then
if request_table.model ~= conf_m.model.name then
return bad_request("cannot use own model - must be: " .. conf_m.model.name)
return bail(400, "cannot use own model - must be: " .. conf_m.model.name)
end
end

-- model is stashed in the copied plugin conf, for consistency in transformation functions
if not conf_m.model.name then
return bad_request("model parameter not found in request, nor in gateway configuration")
return bail(400, "model parameter not found in request, nor in gateway configuration")
end

kong_ctx_plugin.llm_model_requested = conf_m.model.name
@@ -431,7 +448,7 @@
local compatible, err = llm.is_compatible(request_table, route_type)
if not compatible then
kong_ctx_shared.skip_response_transformer = true
return bad_request(err)
return bail(400, err)
end
end

@@ -444,7 +461,7 @@
-- this condition will only check if user has tried
-- to activate streaming mode within their request
if conf_m.response_streaming and conf_m.response_streaming == "deny" then
return bad_request("response streaming is not enabled for this LLM")
return bail(400, "response streaming is not enabled for this LLM")
end

-- store token cost estimate, on first pass, if the
@@ -453,7 +470,7 @@
local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8)
if err then
kong.log.err("unable to estimate request token cost: ", err)
return kong.response.exit(500)
return bail(500)
end

kong_ctx_plugin.ai_stream_prompt_tokens = prompt_tokens
@@ -471,7 +488,7 @@
-- execute pre-request hooks for this driver
local ok, err = ai_driver.pre_request(conf_m, request_table)
if not ok then
return bad_request(err)
return bail(400, err)
end

-- transform the body to Kong-format for this provider/model
@@ -481,17 +498,17 @@
parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type)
if err then
kong_ctx_shared.skip_response_transformer = true
return bad_request(err)
return bail(400, err)
end
end

-- execute pre-request hooks for "all" drivers before set new body
local ok, err = ai_shared.pre_request(conf_m, parsed_request_body)
if not ok then
return bad_request(err)
return bail(400, err)
end

if route_type ~= "preserve" then
if route_type ~= "preserve" and not balancer_phase then
kong.service.request.set_body(parsed_request_body, content_type)
end

@@ -509,7 +526,7 @@
if not ok then
kong_ctx_shared.skip_response_transformer = true
kong.log.err("failed to configure request for AI service: ", err)
return kong.response.exit(500)
return bail(500)
end

-- lights out, and away we go

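As a closing note on the retry flow: the parsed body is stashed once in access, reused by any balancer-phase re-entry, and released in header_filter. A condensed sketch of that lifecycle follows, using the context field and error message from the diff; the wrapper function itself is illustrative and assumes the kong PDK and ngx globals available inside a plugin handler.

```lua
-- Sketch only: cached request-body lifecycle across a balancer retry.
local function read_request_body(conf)
  local body = kong.ctx.shared.ai_request_body
  if body then
    return body  -- a balancer-phase retry reuses the copy parsed in access
  end

  if ngx.get_phase() == "balancer" then
    -- the client body cannot be read anymore once we are retrying upstream
    error("Too late to read body", 2)
  end

  local content_type = kong.request.get_header("Content-Type") or "application/json"
  body = kong.request.get_body(content_type, nil, conf.max_request_body_size)
  kong.ctx.shared.ai_request_body = body  -- stash it for any later retry
  return body
end

-- header_filter then frees the buffer once the upstream has answered:
--   kong.ctx.shared.ai_request_body = nil
```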