Skip to content

Commit

Permalink
fix(ai-proxy): only send compressed response when client requested
Browse files Browse the repository at this point in the history
  • Loading branch information
fffonion committed Aug 7, 2024
1 parent 94d3ddc commit 0bc0c99
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
10 changes: 5 additions & 5 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ local transformers_to = {
}

local transformers_from = {
["llm/v1/chat"] = function(response_string, model_info)
["llm/v1/chat"] = function(response_string, _)
local response_object, err = cjson.decode(response_string)
if err then
return nil, "'choices' not in llm/v1/chat response"
return nil, "failed to decode llm/v1/chat response"
end

if response_object.choices then
return response_string, nil
else
return nil, "'choices' not in llm/v1/chat response"
end
end,

["llm/v1/completions"] = function(response_string, model_info)
["llm/v1/completions"] = function(response_string, _)
local response_object, err = cjson.decode(response_string)
if err then
return nil, "'choices' not in llm/v1/completions response"
return nil, "failed to decode llm/v1/completions response"
end

if response_object.choices then
Expand Down
16 changes: 11 additions & 5 deletions kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ local _KEYBASTION = setmetatable({}, {
})


-- Report whether the client's original request asked for gzip encoding.
-- The flag is cached in kong.ctx.plugin during access(); coerce it to a
-- strict boolean here so callers never see the raw match result.
local function accept_gzip()
  local wants_gzip = kong.ctx.plugin.accept_gzip
  return wants_gzip and true or false
end


-- get the token text from an event frame
local function get_token_text(event_t)
-- get: event_t.choices[1]
Expand All @@ -124,7 +129,6 @@ end
local function handle_streaming_frame(conf, chunk, finished)
-- make a re-usable framebuffer
local framebuffer = buffer.new()
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"

local ai_driver = require("kong.llm.drivers." .. conf.model.provider)

Expand All @@ -140,8 +144,8 @@ local function handle_streaming_frame(conf, chunk, finished)
-- transform each one into flat format, skipping transformer errors
-- because we have already 200 OK'd the client by now

if (not finished) and (is_gzip) then
chunk = kong_utils.inflate_gzip(ngx.arg[1])
if not finished and kong.service.response.get_header("Content-Encoding") == "gzip" then
chunk = kong_utils.inflate_gzip(chunk)
end

local events = ai_shared.frame_to_events(chunk, conf.model.provider)
Expand All @@ -152,7 +156,7 @@ local function handle_streaming_frame(conf, chunk, finished)
-- and then send the client a readable error in a single chunk
local response = ERROR__NOT_SET

if is_gzip then
if accept_gzip() then
response = kong_utils.deflate_gzip(response)
end

Expand Down Expand Up @@ -234,7 +238,7 @@ local function handle_streaming_frame(conf, chunk, finished)
end

local response_frame = framebuffer:get()
if (not finished) and (is_gzip) then
if not finished and accept_gzip() then
response_frame = kong_utils.deflate_gzip(response_frame)
end

Expand Down Expand Up @@ -372,6 +376,8 @@ end

function _M:access(conf)
local kong_ctx_plugin = kong.ctx.plugin
-- record the request header very early, otherwise kong.service.request.set_header will pollute it
kong_ctx_plugin.accept_gzip = (kong.request.get_header("Accept-Encoding") or ""):match("%f[%a]gzip%f[%A]")

-- store the route_type in ctx for use in response parsing
local route_type = conf.route_type
Expand Down
5 changes: 5 additions & 0 deletions spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
["content-type"] = "application/json",
["accept"] = "application/json",
["x-test-type"] = "200",
["accept-encoding"] = "gzip, identity"
},
body = format_stencils.llm_v1_chat.good.user_request,
})
Expand Down Expand Up @@ -287,6 +288,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
["content-type"] = "application/json",
["accept"] = "application/json",
["x-test-type"] = "200_FAULTY",
["accept-encoding"] = "gzip, identity"
},
body = format_stencils.llm_v1_chat.good.user_request,
})
Expand All @@ -307,6 +309,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
["content-type"] = "application/json",
["accept"] = "application/json",
["x-test-type"] = "401",
["accept-encoding"] = "gzip, identity"
},
body = format_stencils.llm_v1_chat.good.user_request,
})
Expand All @@ -327,6 +330,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
["content-type"] = "application/json",
["accept"] = "application/json",
["x-test-type"] = "500",
["accept-encoding"] = "gzip, identity"
},
body = format_stencils.llm_v1_chat.good.user_request,
})
Expand All @@ -347,6 +351,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
["content-type"] = "application/json",
["accept"] = "application/json",
["x-test-type"] = "500_FAULTY",
["accept-encoding"] = "gzip, identity"
},
body = format_stencils.llm_v1_chat.good.user_request,
})
Expand Down

0 comments on commit 0bc0c99

Please sign in to comment.