From 0bc0c99f2be55b358f2fccc7b8a80b3b63d9af1d Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Wed, 17 Jul 2024 22:11:03 +0800 Subject: [PATCH] fix(ai-proxy): only send compressed response when client requested --- kong/llm/drivers/openai.lua | 10 +++++----- kong/llm/proxy/handler.lua | 16 +++++++++++----- .../38-ai-proxy/08-encoding_integration_spec.lua | 5 +++++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 1c592e5ef60b..ca1e37d91831 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -35,12 +35,12 @@ local transformers_to = { } local transformers_from = { - ["llm/v1/chat"] = function(response_string, model_info) + ["llm/v1/chat"] = function(response_string, _) local response_object, err = cjson.decode(response_string) if err then - return nil, "'choices' not in llm/v1/chat response" + return nil, "failed to decode llm/v1/chat response" end - + if response_object.choices then return response_string, nil else @@ -48,10 +48,10 @@ local transformers_from = { end end, - ["llm/v1/completions"] = function(response_string, model_info) + ["llm/v1/completions"] = function(response_string, _) local response_object, err = cjson.decode(response_string) if err then - return nil, "'choices' not in llm/v1/completions response" + return nil, "failed to decode llm/v1/completions response" end if response_object.choices then diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 82769b625b05..d6c7fd1ec6fc 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -109,6 +109,11 @@ local _KEYBASTION = setmetatable({}, { }) +local function accept_gzip() + return not not kong.ctx.plugin.accept_gzip +end + + -- get the token text from an event frame local function get_token_text(event_t) -- get: event_t.choices[1] @@ -124,7 +129,6 @@ end local function handle_streaming_frame(conf, chunk, finished) -- make a re-usable framebuffer local framebuffer = buffer.new() - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" local ai_driver = require("kong.llm.drivers." .. conf.model.provider) @@ -140,8 +144,8 @@ local function handle_streaming_frame(conf, chunk, finished) -- transform each one into flat format, skipping transformer errors -- because we have already 200 OK'd the client by now - if (not finished) and (is_gzip) then - chunk = kong_utils.inflate_gzip(ngx.arg[1]) + if not finished and kong.service.response.get_header("Content-Encoding") == "gzip" then + chunk = kong_utils.inflate_gzip(chunk) end local events = ai_shared.frame_to_events(chunk, conf.model.provider) @@ -152,7 +156,7 @@ local function handle_streaming_frame(conf, chunk, finished) -- and then send the client a readable error in a single chunk local response = ERROR__NOT_SET - if is_gzip then + if accept_gzip() then response = kong_utils.deflate_gzip(response) end @@ -234,7 +238,7 @@ local function handle_streaming_frame(conf, chunk, finished) end local response_frame = framebuffer:get() - if (not finished) and (is_gzip) then + if not finished and accept_gzip() then response_frame = kong_utils.deflate_gzip(response_frame) end @@ -372,6 +376,8 @@ end function _M:access(conf) local kong_ctx_plugin = kong.ctx.plugin + -- record the request header very early, otherwise kong.serivce.request.set_header will polute it + kong_ctx_plugin.accept_gzip = (kong.request.get_header("Accept-Encoding") or ""):match("%f[%a]gzip%f[%A]") -- store the route_type in ctx for use in response parsing local route_type = conf.route_type diff --git a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua index 049920e460bb..0cc63ba41ba4 100644 --- a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua @@ -257,6 +257,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "200", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -287,6 +288,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "200_FAULTY", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -307,6 +309,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "401", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -327,6 +330,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "500", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -347,6 +351,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "500_FAULTY", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, })