From 488b3d8f779345052140771edcf6cee0ae634563 Mon Sep 17 00:00:00 2001
From: Wangchong Zhou
Date: Tue, 4 Jun 2024 15:39:42 +0800
Subject: [PATCH 1/5] refactor(plugins): move shared ctx usage of ai plugins to use a proper API

To make typos more obvious and easier to catch

--- kong-3.8.0-0.rockspec | 2 + kong/llm/drivers/shared.lua | 23 +++-- kong/llm/proxy/handler.lua | 46 ++++----- kong/llm/state.lua | 95 +++++++++++++++++++ kong/plugins/ai-prompt-decorator/handler.lua | 3 +- kong/plugins/ai-prompt-guard/handler.lua | 3 +- kong/plugins/ai-prompt-template/handler.lua | 3 +- kong/plugins/ai-proxy/handler.lua | 1 - .../ai-request-transformer/handler.lua | 3 +- .../ai-response-transformer/handler.lua | 5 +- kong/reports.lua | 12 ++- 11 files changed, 153 insertions(+), 43 deletions(-) create mode 100644 kong/llm/state.lua
diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index bab08660e7d..3dc070e8684 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec
@@ -614,6 +614,8 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", + ["kong.llm.state"] = "kong/llm/state.lua", + ["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua", ["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua",
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 8c0c88e6573..92bc72ea641 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua
@@ -1,12 +1,21 @@ local _M = {} -- imports <<<<<<< HEAD local cjson = require("cjson.safe") local http = require("resty.http") local fmt = string.format local os = os local parse_url = require("socket.url").parse local aws_stream = require("kong.tools.aws_stream") ======= local cjson = require("cjson.safe") local http = require("resty.http") local fmt = string.format local os = os local parse_url = require("socket.url").parse local llm_state = require("kong.llm.state") >>>>>>> d9053432f9 (refactor(plugins): move shared ctx usage of ai plugins to use a proper API) -- -- static
@@ -341,7 +350,7 @@ function _M.frame_to_events(frame, provider) end -- if end end - + return events end
@@ -500,7 +509,7 @@ function _M.resolve_plugin_conf(kong_request, conf) if #splitted ~= 2 then return nil, "cannot parse expression for field '" .. v .. "'" end - + -- find the request parameter, with the configured name prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) if err then
@@ -524,7 +533,7 @@ function _M.pre_request(conf, request_table) local auth_param_name = conf.auth and conf.auth.param_name local auth_param_value = conf.auth and conf.auth.param_value local auth_param_location = conf.auth and conf.auth.param_location - + if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then request_table[auth_param_name] = auth_param_value end
@@ -547,7 +556,7 @@ function _M.pre_request(conf, request_table) kong.log.warn("failed calculating cost for prompt tokens: ", err) prompt_tokens = 0 end - kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens + llm_state.increase_prompt_tokens_count(prompt_tokens) end local start_time_key = "ai_request_start_time_" ..
plugin_name @@ -586,7 +595,7 @@ function _M.post_request(conf, response_object) end -- check if we already have analytics in this context - local request_analytics = kong.ctx.shared.analytics + local request_analytics = llm_state.get_request_analytics() -- create a new structure if not if not request_analytics then @@ -657,7 +666,7 @@ function _M.post_request(conf, response_object) [log_entry_keys.RESPONSE_BODY] = body_string, } request_analytics[plugin_name] = request_analytics_plugin - kong.ctx.shared.analytics = request_analytics + llm_state.set_request_analytics(request_analytics) if conf.logging and conf.logging.log_statistics then -- Log meta data @@ -679,7 +688,7 @@ function _M.post_request(conf, response_object) kong.log.warn("failed calculating cost for response tokens: ", err) response_tokens = 0 end - kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens + llm_state.increase_response_tokens_count(response_tokens) return nil end diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index de028bd7ee4..21013518f87 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -7,6 +7,7 @@ local ai_shared = require("kong.llm.drivers.shared") local llm = require("kong.llm") +local llm_state = require("kong.llm.state") local cjson = require("cjson.safe") local kong_utils = require("kong.tools.gzip") local buffer = require "string.buffer" @@ -265,9 +266,8 @@ function _M:header_filter(conf) kong.ctx.shared.ai_request_body = nil local kong_ctx_plugin = kong.ctx.plugin - local kong_ctx_shared = kong.ctx.shared - if kong_ctx_shared.skip_response_transformer then + if llm_state.is_response_transformer_skipped() then return end @@ -282,7 +282,7 @@ function _M:header_filter(conf) end -- we use openai's streaming mode (SSE) - if kong_ctx_shared.ai_proxy_streaming_mode then + if llm_state.is_streaming_mode() then -- we are going to send plaintext event-stream frames for ALL models kong.response.set_header("Content-Type", "text/event-stream") return @@ -299,7 +299,7 @@ function _M:header_filter(conf) -- if this is a 'streaming' request, we can't know the final -- result of the response body, so we just proceed to body_filter -- to translate each SSE event frame - if not kong_ctx_shared.ai_proxy_streaming_mode then + if not llm_state.is_streaming_mode() then local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" if is_gzip then response_body = kong_utils.inflate_gzip(response_body) @@ -329,22 +329,18 @@ end function _M:body_filter(conf) local kong_ctx_plugin = kong.ctx.plugin - local kong_ctx_shared = kong.ctx.shared -- if body_filter is called twice, then return - if kong_ctx_plugin.body_called and not kong_ctx_shared.ai_proxy_streaming_mode then + if kong_ctx_plugin.body_called and not llm_state.is_streaming_mode() then return end local route_type = conf.route_type - if kong_ctx_shared.skip_response_transformer and (route_type ~= "preserve") then - local response_body + if llm_state.is_response_transformer_skipped() and (route_type ~= "preserve") then + local response_body = llm_state.get_parsed_response() - if kong_ctx_shared.parsed_response then - response_body = kong_ctx_shared.parsed_response - - elseif kong.response.get_status() == 200 then + if not response_body and kong.response.get_status() == 200 then response_body = kong.service.response.get_raw_body() if not response_body then kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.", @@ -355,6 +351,8 @@ function 
_M:body_filter(conf) response_body = kong_utils.inflate_gzip(response_body) end end + else + kong.response.exit(500, "no response body found") end local ai_driver = require("kong.llm.drivers." .. conf.model.provider) @@ -368,13 +366,13 @@ function _M:body_filter(conf) end end - if not kong_ctx_shared.skip_response_transformer then + if not llm_state.is_response_transformer_skipped() then if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then return end if route_type ~= "preserve" then - if kong_ctx_shared.ai_proxy_streaming_mode then + if llm_state.is_streaming_mode() then handle_streaming_frame(conf) else -- all errors MUST be checked and returned in header_filter @@ -406,14 +404,12 @@ end function _M:access(conf) local kong_ctx_plugin = kong.ctx.plugin - local kong_ctx_shared = kong.ctx.shared -- store the route_type in ctx for use in response parsing local route_type = conf.route_type kong_ctx_plugin.operation = route_type - local request_table local multipart = false -- TODO: the access phase may be called mulitple times also in the balancer phase @@ -421,22 +417,22 @@ function _M:access(conf) local balancer_phase = ngx.get_phase() == "balancer" -- we may have received a replacement / decorated request body from another AI plugin - if kong_ctx_shared.replacement_request then + local request_table = llm_state.get_replacement_response() -- not used + if request_table then kong.log.debug("replacement request body received from another AI plugin") - request_table = kong_ctx_shared.replacement_request else -- first, calculate the coordinates of the request local content_type = kong.request.get_header("Content-Type") or "application/json" - request_table = kong_ctx_shared.ai_request_body + request_table = llm_state.get_request_body_table() if not request_table then if balancer_phase then error("Too late to read body", 2) end request_table = kong.request.get_body(content_type, nil, conf.max_request_body_size) - kong_ctx_shared.ai_request_body = request_table + llm_state.set_request_body_table(request_table) end if not request_table then @@ -481,7 +477,7 @@ function _M:access(conf) if not multipart then local compatible, err = llm.is_compatible(request_table, route_type) if not compatible then - kong_ctx_shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() return bail(400, err) end end @@ -511,7 +507,7 @@ function _M:access(conf) end -- specific actions need to skip later for this to work - kong_ctx_shared.ai_proxy_streaming_mode = true + llm_state.set_streaming_mode() else kong.service.request.enable_buffering() @@ -531,7 +527,7 @@ function _M:access(conf) -- transform the body to Kong-format for this provider/model parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type) if err then - kong_ctx_shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() return bail(400, err) end end @@ -549,7 +545,7 @@ function _M:access(conf) -- get the provider's cached identity interface - nil may come back, which is fine local identity_interface = _KEYBASTION[conf] if identity_interface and identity_interface.error then - kong.ctx.shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() kong.log.err("error authenticating with cloud-provider, ", identity_interface.error) return bail(500, "LLM request failed before proxying") end @@ -558,7 +554,7 @@ function _M:access(conf) local ok, err = ai_driver.configure_request(conf_m, identity_interface 
and identity_interface.interface) if not ok then - kong_ctx_shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() kong.log.err("failed to configure request for AI service: ", err) return bail(500) end diff --git a/kong/llm/state.lua b/kong/llm/state.lua new file mode 100644 index 00000000000..e45a25f16ef --- /dev/null +++ b/kong/llm/state.lua @@ -0,0 +1,95 @@ +local _M = {} + +function _M.disable_ai_proxy_response_transform() + kong.ctx.shared.llm_disable_ai_proxy_response_transform = true +end + +function _M.should_disable_ai_proxy_response_transform() + return kong.ctx.shared.llm_disable_ai_proxy_response_transform == true +end + +function _M.set_prompt_decorated() + kong.ctx.shared.llm_prompt_decorated = true +end + +function _M.is_prompt_decorated() + return kong.ctx.shared.llm_prompt_decorated == true +end + +function _M.set_prompt_guarded() + kong.ctx.shared.llm_prompt_guarded = true +end + +function _M.is_prompt_guarded() + return kong.ctx.shared.llm_prompt_guarded == true +end + +function _M.set_prompt_templated() + kong.ctx.shared.llm_prompt_templated = true +end + +function _M.is_prompt_templated() + return kong.ctx.shared.llm_prompt_templated == true +end + +function _M.set_streaming_mode() + kong.ctx.shared.llm_streaming_mode = true +end + +function _M.is_streaming_mode() + return kong.ctx.shared.llm_streaming_mode == true +end + +function _M.set_parsed_response(response) + kong.ctx.shared.llm_parsed_response = response +end + +function _M.get_parsed_response() + return kong.ctx.shared.llm_parsed_response +end + +function _M.set_request_body_table(body_t) + kong.ctx.shared.llm_request_body_t = body_t +end + +function _M.get_request_body_table() + return kong.ctx.shared.llm_request_body_t +end + +function _M.set_replacement_response(response) + kong.ctx.shared.llm_replacement_response = response +end + +function _M.get_replacement_response() + return kong.ctx.shared.llm_replacement_response +end + +function _M.set_request_analytics(tbl) + kong.ctx.shared.llm_request_analytics = tbl +end + +function _M.get_request_analytics() + return kong.ctx.shared.llm_request_analytics +end + +function _M.increase_prompt_tokens_count(by) + local count = (kong.ctx.shared.llm_prompt_tokens_count or 0) + by + kong.ctx.shared.llm_prompt_tokens_count = count + return count +end + +function _M.get_prompt_tokens_count() + return kong.ctx.shared.llm_prompt_tokens_count +end + +function _M.increase_response_tokens_count(by) + local count = (kong.ctx.shared.llm_response_tokens_count or 0) + by + kong.ctx.shared.llm_response_tokens_count = count + return count +end + +function _M.get_response_tokens_count() + return kong.ctx.shared.llm_response_tokens_count +end + +return _M \ No newline at end of file diff --git a/kong/plugins/ai-prompt-decorator/handler.lua b/kong/plugins/ai-prompt-decorator/handler.lua index 23a18ea7399..4600c1d35db 100644 --- a/kong/plugins/ai-prompt-decorator/handler.lua +++ b/kong/plugins/ai-prompt-decorator/handler.lua @@ -1,4 +1,5 @@ local new_tab = require("table.new") +local llm_state = require("kong.llm.state") local EMPTY = {} @@ -52,7 +53,7 @@ end function plugin:access(conf) kong.service.request.enable_buffering() - kong.ctx.shared.ai_prompt_decorated = true -- future use + llm_state.set_prompt_decorated() -- future use -- if plugin ordering was altered, receive the "decorated" request local request = kong.request.get_body("application/json", nil, conf.max_request_body_size) diff --git a/kong/plugins/ai-prompt-guard/handler.lua 
b/kong/plugins/ai-prompt-guard/handler.lua index b2aab78dbc7..5f4f8f4b369 100644 --- a/kong/plugins/ai-prompt-guard/handler.lua +++ b/kong/plugins/ai-prompt-guard/handler.lua @@ -1,4 +1,5 @@ local buffer = require("string.buffer") +local llm_state = require("kong.llm.state") local ngx_re_find = ngx.re.find local EMPTY = {} @@ -116,7 +117,7 @@ end function plugin:access(conf) kong.service.request.enable_buffering() - kong.ctx.shared.ai_prompt_guarded = true -- future use + llm_state.set_prompt_guarded() -- future use -- if plugin ordering was altered, receive the "decorated" request local request = kong.request.get_body("application/json", nil, conf.max_request_body_size) diff --git a/kong/plugins/ai-prompt-template/handler.lua b/kong/plugins/ai-prompt-template/handler.lua index 2be9137c9fe..11674717009 100644 --- a/kong/plugins/ai-prompt-template/handler.lua +++ b/kong/plugins/ai-prompt-template/handler.lua @@ -1,5 +1,6 @@ local templater = require("kong.plugins.ai-prompt-template.templater") +local llm_state = require("kong.llm.state") local ipairs = ipairs local type = type @@ -61,7 +62,7 @@ end function AIPromptTemplateHandler:access(conf) kong.service.request.enable_buffering() - kong.ctx.shared.ai_prompt_templated = true + llm_state.set_prompt_templated() if conf.log_original_request then kong.log.set_serialize_value(LOG_ENTRY_KEYS.REQUEST_BODY, kong.request.get_raw_body(conf.max_request_body_size)) diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index f2fc8df8985..558f4f24198 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -1,4 +1,3 @@ - local kong_meta = require("kong.meta") local deep_copy = require "kong.tools.table".deep_copy diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua index 222ed079aa8..c7316cf35ed 100644 --- a/kong/plugins/ai-request-transformer/handler.lua +++ b/kong/plugins/ai-request-transformer/handler.lua @@ -4,6 +4,7 @@ local _M = {} local kong_meta = require "kong.meta" local fmt = string.format local llm = require("kong.llm") +local llm_state = require("kong.llm.state") -- _M.PRIORITY = 777 @@ -40,7 +41,7 @@ end function _M:access(conf) kong.service.request.enable_buffering() - kong.ctx.shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf) diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua index c1e154dbd06..a350ee74e6f 100644 --- a/kong/plugins/ai-response-transformer/handler.lua +++ b/kong/plugins/ai-response-transformer/handler.lua @@ -6,6 +6,7 @@ local http = require("resty.http") local fmt = string.format local kong_utils = require("kong.tools.gzip") local llm = require("kong.llm") +local llm_state = require("kong.llm.state") -- _M.PRIORITY = 769 @@ -99,7 +100,7 @@ end function _M:access(conf) kong.service.request.enable_buffering() - kong.ctx.shared.skip_response_transformer = true + llm_state.set_response_transformer_skipped() -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf) @@ -126,7 +127,7 @@ function _M:access(conf) res_body = kong_utils.inflate_gzip(res_body) end - kong.ctx.shared.parsed_response = res_body + llm_state.set_parsed_response(res_body) -- future use -- if asked, introspect the request before proxying kong.log.debug("introspecting response with LLM") diff --git 
a/kong/reports.lua b/kong/reports.lua index 2ce5777b29f..31f6c40ec42 100644 --- a/kong/reports.lua +++ b/kong/reports.lua @@ -6,6 +6,8 @@ local counter = require "resty.counter" local knode = (kong and kong.node) and kong.node or require "kong.pdk.node".new() +local llm_state = require "kong.llm.state" + local kong_dict = ngx.shared.kong local ngx = ngx @@ -525,13 +527,15 @@ return { incr_counter(WASM_REQUEST_COUNT_KEY) end - if kong.ctx.shared.ai_prompt_tokens then + local llm_prompt_tokens_count = llm_state.get_prompt_tokens_count() + if llm_prompt_tokens_count then incr_counter(AI_REQUEST_COUNT_KEY) - incr_counter(AI_PROMPT_TOKENS_COUNT_KEY, kong.ctx.shared.ai_prompt_tokens) + incr_counter(AI_PROMPT_TOKENS_COUNT_KEY, llm_prompt_tokens_count) end - if kong.ctx.shared.ai_response_tokens then - incr_counter(AI_RESPONSE_TOKENS_COUNT_KEY, kong.ctx.shared.ai_response_tokens) + local llm_response_tokens_count = llm_state.get_response_tokens_count() + if llm_response_tokens_count then + incr_counter(AI_RESPONSE_TOKENS_COUNT_KEY, llm_response_tokens_count) end local suffix = get_current_suffix(ctx) From e70c60bee335d1787d662c586dbee13c5da9e66d Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Tue, 4 Jun 2024 16:54:34 +0800 Subject: [PATCH 2/5] refactor(ai-proxy): cleanup body_filter and header_filter --- kong/llm/drivers/shared.lua | 10 +- kong/llm/proxy/handler.lua | 202 ++++++++---------- kong/llm/state.lua | 2 + .../ai-request-transformer/handler.lua | 2 +- .../ai-response-transformer/handler.lua | 2 +- 5 files changed, 89 insertions(+), 129 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 92bc72ea641..8004b9384df 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -1,21 +1,13 @@ local _M = {} -- imports -<<<<<<< HEAD -local cjson = require("cjson.safe") -local http = require("resty.http") -local fmt = string.format -local os = os -local parse_url = require("socket.url").parse -local aws_stream = require("kong.tools.aws_stream") -======= local cjson = require("cjson.safe") local http = require("resty.http") local fmt = string.format local os = os local parse_url = require("socket.url").parse local llm_state = require("kong.llm.state") ->>>>>>> d9053432f9 (refactor(plugins): move shared ctx usage of ai plugins to use a proper API) +local aws_stream = require("kong.tools.aws_stream") -- -- static diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 21013518f87..82769b625b0 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -121,8 +121,7 @@ local function get_token_text(event_t) return (type(token_text) == "string" and token_text) or "" end - -local function handle_streaming_frame(conf) +local function handle_streaming_frame(conf, chunk, finished) -- make a re-usable framebuffer local framebuffer = buffer.new() local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" @@ -136,9 +135,6 @@ local function handle_streaming_frame(conf) kong_ctx_plugin.ai_stream_log_buffer = buffer.new() end - -- now handle each chunk/frame - local chunk = ngx.arg[1] - local finished = ngx.arg[2] if type(chunk) == "string" and chunk ~= "" then -- transform each one into flat format, skipping transformer errors @@ -261,144 +257,116 @@ local function handle_streaming_frame(conf) end end -function _M:header_filter(conf) - -- free up the buffered body used in the access phase - kong.ctx.shared.ai_request_body = nil - - local kong_ctx_plugin = kong.ctx.plugin +local function transform_body(conf) 
+ local route_type = conf.route_type + local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - if llm_state.is_response_transformer_skipped() then - return - end + -- Note: below even if we are told not to do response transform, we still need to do + -- get the body for analytics - -- clear shared restricted headers - for _, v in ipairs(ai_shared.clear_response_headers.shared) do - kong.response.clear_header(v) - end + -- try parsed response from other plugin first + local response_body = llm_state.get_parsed_response() + -- read from upstream if it's not been parsed/transformed by other plugins + if not response_body then + response_body = kong.service.response.get_raw_body() - -- only act on 200 in first release - pass the unmodifed response all the way through if any failure - if kong.response.get_status() ~= 200 then - return + if response_body and kong.service.response.get_header("Content-Encoding") == "gzip" then + response_body = kong_utils.inflate_gzip(response_body) + end end - -- we use openai's streaming mode (SSE) - if llm_state.is_streaming_mode() then - -- we are going to send plaintext event-stream frames for ALL models - kong.response.set_header("Content-Type", "text/event-stream") - return - end + local err - local response_body = kong.service.response.get_raw_body() if not response_body then - return - end + err = "no response body found when transforming response" - local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - local route_type = conf.route_type + elseif route_type ~= "preserve" then + response_body, err = ai_driver.from_format(response_body, conf.model, route_type) - -- if this is a 'streaming' request, we can't know the final - -- result of the response body, so we just proceed to body_filter - -- to translate each SSE event frame - if not llm_state.is_streaming_mode() then - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - response_body = kong_utils.inflate_gzip(response_body) + if err then + kong.log.err("issue when transforming the response body for analytics: ", err) end + end - if route_type == "preserve" then - kong_ctx_plugin.parsed_response = response_body - else - local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) - if err then - kong_ctx_plugin.ai_parser_error = true + if err then + ngx.status = 500 + response_body = cjson.encode({ error = { message = err }}) - ngx.status = 500 - kong_ctx_plugin.parsed_response = cjson.encode({ error = { message = err } }) + else + ai_shared.post_request(conf, response_body) + end - elseif new_response_string then - -- preserve the same response content type; assume the from_format function - -- has returned the body in the appropriate response output format - kong_ctx_plugin.parsed_response = new_response_string - end - end + if accept_gzip() then + response_body = kong_utils.deflate_gzip(response_body) end - ai_driver.post_request(conf) + kong.ctx.plugin.buffered_response_body = response_body end +function _M:header_filter(conf) + -- free up the buffered body used in the access phase + llm_state.set_request_body_table(nil) -function _M:body_filter(conf) - local kong_ctx_plugin = kong.ctx.plugin + local ai_driver = require("kong.llm.drivers." .. 
conf.model.provider) + ai_driver.post_request(conf) - -- if body_filter is called twice, then return - if kong_ctx_plugin.body_called and not llm_state.is_streaming_mode() then + if llm_state.should_disable_ai_proxy_response_transform() then return end - local route_type = conf.route_type + -- only act on 200 in first release - pass the unmodifed response all the way through if any failure + if kong.response.get_status() ~= 200 then + return + end - if llm_state.is_response_transformer_skipped() and (route_type ~= "preserve") then - local response_body = llm_state.get_parsed_response() - - if not response_body and kong.response.get_status() == 200 then - response_body = kong.service.response.get_raw_body() - if not response_body then - kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.", - " Please check AI request transformer plugin response.") - else - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - response_body = kong_utils.inflate_gzip(response_body) - end - end - else - kong.response.exit(500, "no response body found") - end + -- if not streaming, prepare the response body buffer + -- this must be called before sending any response headers so that + -- we can modify status code if needed + if not llm_state.is_streaming_mode() then + transform_body(conf) + end - local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) + -- clear shared restricted headers + for _, v in ipairs(ai_shared.clear_response_headers.shared) do + kong.response.clear_header(v) + end - if err then - kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err) + -- we use openai's streaming mode (SSE) + if llm_state.is_streaming_mode() then + -- we are going to send plaintext event-stream frames for ALL models + kong.response.set_header("Content-Type", "text/event-stream") + end - elseif new_response_string then - ai_shared.post_request(conf, new_response_string) - end + if accept_gzip() then + kong.response.set_header("Content-Encoding", "gzip") + else + kong.response.clear_header("Content-Encoding") end +end - if not llm_state.is_response_transformer_skipped() then - if (kong.response.get_status() ~= 200) and (not kong_ctx_plugin.ai_parser_error) then - return - end - if route_type ~= "preserve" then - if llm_state.is_streaming_mode() then - handle_streaming_frame(conf) - else - -- all errors MUST be checked and returned in header_filter - -- we should receive a replacement response body from the same thread - local original_request = kong_ctx_plugin.parsed_response - local deflated_request = original_request - - if deflated_request then - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - deflated_request = kong_utils.deflate_gzip(deflated_request) - end +-- body filter is only used for streaming mode; for non-streaming mode, everything +-- is already done in header_filter. 
This is because it would be too late to +-- send the status code if we are modifying non-streaming body in body_filter +function _M:body_filter(conf) + if kong.service.response.get_status() ~= 200 then + return + end - kong.response.set_raw_body(deflated_request) - end + -- emit the full body if not streaming + if not llm_state.is_streaming_mode() then + ngx.arg[1] = kong.ctx.plugin.buffered_response_body + ngx.arg[2] = true - -- call with replacement body, or original body if nothing changed - local _, err = ai_shared.post_request(conf, original_request) - if err then - kong.log.warn("analytics phase failed for request, ", err) - end - end - end + kong.ctx.plugin.buffered_response_body = nil + return end - kong_ctx_plugin.body_called = true + if not llm_state.should_disable_ai_proxy_response_transform() and + conf.route_type ~= "preserve" then + + handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2]) + end end
@@ -474,12 +442,10 @@ function _M:access(conf) kong_ctx_plugin.llm_model_requested = conf_m.model.name -- check the incoming format is the same as the configured LLM format - if not multipart then - local compatible, err = llm.is_compatible(request_table, route_type) - if not compatible then - llm_state.set_response_transformer_skipped() - return bail(400, err) - end + local compatible, err = llm.is_compatible(request_table, route_type) + if not multipart and not compatible then + llm_state.disable_ai_proxy_response_transform() + return bail(400, err) end -- check if the user has asked for a stream, and/or if
@@ -527,7 +493,7 @@ -- transform the body to Kong-format for this provider/model parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf_m.model, route_type) if err then - llm_state.set_response_transformer_skipped() + llm_state.disable_ai_proxy_response_transform() return bail(400, err) end end
@@ -554,7 +520,7 @@ local ok, err = ai_driver.configure_request(conf_m, identity_interface and identity_interface.interface) if not ok then - llm_state.set_response_transformer_skipped() + llm_state.disable_ai_proxy_response_transform() kong.log.err("failed to configure request for AI service: ", err) return bail(500) end
diff --git a/kong/llm/state.lua b/kong/llm/state.lua index e45a25f16ef..fa2c29edf8c 100644 --- a/kong/llm/state.lua +++ b/kong/llm/state.lua
@@ -1,5 +1,7 @@ local _M = {} +-- Set disable_ai_proxy_response_transform if the response is just an error message or has been generated +-- by a plugin, and should skip further transformation function _M.disable_ai_proxy_response_transform() kong.ctx.shared.llm_disable_ai_proxy_response_transform = true end
diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua index c7316cf35ed..1bad3a92db3 100644 --- a/kong/plugins/ai-request-transformer/handler.lua +++ b/kong/plugins/ai-request-transformer/handler.lua
@@ -41,7 +41,7 @@ end function _M:access(conf) kong.service.request.enable_buffering() - llm_state.set_response_transformer_skipped() + llm_state.disable_ai_proxy_response_transform() -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf)
diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua index a350ee74e6f..815b64f351f 100644 --- a/kong/plugins/ai-response-transformer/handler.lua +++ b/kong/plugins/ai-response-transformer/handler.lua
@@ -100,7 +100,7 @@ end function _M:access(conf)
kong.service.request.enable_buffering() - llm_state.set_response_transformer_skipped() + llm_state.disable_ai_proxy_response_transform() -- first find the configured LLM interface and driver local http_opts = create_http_opts(conf) From 2f2b308590503e69cc3ddf1ab5d791370502603d Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Wed, 17 Jul 2024 22:11:03 +0800 Subject: [PATCH 3/5] fix(ai-proxy): only send compressed response when client requested --- kong/llm/drivers/openai.lua | 10 +++++----- kong/llm/proxy/handler.lua | 16 +++++++++++----- .../38-ai-proxy/08-encoding_integration_spec.lua | 5 +++++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 1c592e5ef60..ca1e37d9183 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -35,12 +35,12 @@ local transformers_to = { } local transformers_from = { - ["llm/v1/chat"] = function(response_string, model_info) + ["llm/v1/chat"] = function(response_string, _) local response_object, err = cjson.decode(response_string) if err then - return nil, "'choices' not in llm/v1/chat response" + return nil, "failed to decode llm/v1/chat response" end - + if response_object.choices then return response_string, nil else @@ -48,10 +48,10 @@ local transformers_from = { end end, - ["llm/v1/completions"] = function(response_string, model_info) + ["llm/v1/completions"] = function(response_string, _) local response_object, err = cjson.decode(response_string) if err then - return nil, "'choices' not in llm/v1/completions response" + return nil, "failed to decode llm/v1/completions response" end if response_object.choices then diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 82769b625b0..d6c7fd1ec6f 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -109,6 +109,11 @@ local _KEYBASTION = setmetatable({}, { }) +local function accept_gzip() + return not not kong.ctx.plugin.accept_gzip +end + + -- get the token text from an event frame local function get_token_text(event_t) -- get: event_t.choices[1] @@ -124,7 +129,6 @@ end local function handle_streaming_frame(conf, chunk, finished) -- make a re-usable framebuffer local framebuffer = buffer.new() - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" local ai_driver = require("kong.llm.drivers." .. 
conf.model.provider) @@ -140,8 +144,8 @@ local function handle_streaming_frame(conf, chunk, finished) -- transform each one into flat format, skipping transformer errors -- because we have already 200 OK'd the client by now - if (not finished) and (is_gzip) then - chunk = kong_utils.inflate_gzip(ngx.arg[1]) + if not finished and kong.service.response.get_header("Content-Encoding") == "gzip" then + chunk = kong_utils.inflate_gzip(chunk) end local events = ai_shared.frame_to_events(chunk, conf.model.provider) @@ -152,7 +156,7 @@ local function handle_streaming_frame(conf, chunk, finished) -- and then send the client a readable error in a single chunk local response = ERROR__NOT_SET - if is_gzip then + if accept_gzip() then response = kong_utils.deflate_gzip(response) end @@ -234,7 +238,7 @@ local function handle_streaming_frame(conf, chunk, finished) end local response_frame = framebuffer:get() - if (not finished) and (is_gzip) then + if not finished and accept_gzip() then response_frame = kong_utils.deflate_gzip(response_frame) end @@ -372,6 +376,8 @@ end function _M:access(conf) local kong_ctx_plugin = kong.ctx.plugin + -- record the request header very early, otherwise kong.serivce.request.set_header will polute it + kong_ctx_plugin.accept_gzip = (kong.request.get_header("Accept-Encoding") or ""):match("%f[%a]gzip%f[%A]") -- store the route_type in ctx for use in response parsing local route_type = conf.route_type diff --git a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua index 049920e460b..0cc63ba41ba 100644 --- a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua @@ -257,6 +257,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "200", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -287,6 +288,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "200_FAULTY", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -307,6 +309,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "401", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -327,6 +330,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "500", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) @@ -347,6 +351,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ["content-type"] = "application/json", ["accept"] = "application/json", ["x-test-type"] = "500_FAULTY", + ["accept-encoding"] = "gzip, identity" }, body = format_stencils.llm_v1_chat.good.user_request, }) From eb4e3c57276dd65582bc2d36ce53273ab6632a75 Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Tue, 6 Aug 2024 13:08:06 +0800 Subject: [PATCH 4/5] chore(*): remove empty line with spaces --- kong/llm/drivers/anthropic.lua | 4 +- kong/llm/drivers/azure.lua | 2 +- kong/llm/drivers/bedrock.lua | 2 +- 
kong/llm/drivers/cohere.lua | 52 +++++++++---------- kong/llm/drivers/gemini.lua | 4 +- kong/llm/drivers/llama2.lua | 2 +- kong/llm/drivers/openai.lua | 4 +- kong/llm/drivers/shared.lua | 4 +- .../22-ai_plugins/01-reports_spec.lua | 16 +++--- .../03-plugins/38-ai-proxy/00-config_spec.lua | 12 ++--- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 14 ++--- .../02-openai_integration_spec.lua | 52 +++++++++---------- .../03-anthropic_integration_spec.lua | 30 +++++------ .../04-cohere_integration_spec.lua | 30 +++++------ .../38-ai-proxy/05-azure_integration_spec.lua | 34 ++++++------ .../06-mistral_integration_spec.lua | 14 ++--- .../07-llama2_integration_spec.lua | 18 +++---- .../08-encoding_integration_spec.lua | 4 +- .../09-streaming_integration_spec.lua | 6 +-- .../02-integration_spec.lua | 10 ++-- .../42-ai-prompt-guard/01-unit_spec.lua | 2 +- 21 files changed, 158 insertions(+), 158 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 77c9f363f9b..7161804cd93 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -71,7 +71,7 @@ local function to_claude_prompt(req) return kong_messages_to_claude_prompt(req.messages) end - + return nil, "request is missing .prompt and .messages commands" end @@ -328,7 +328,7 @@ function _M.from_format(response_string, model_info, route_type) if not transform then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index a0ba1741a86..343904ffad2 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -139,7 +139,7 @@ function _M.configure_request(conf) -- technically min supported version query_table["api-version"] = kong.request.get_query_arg("api-version") or (conf.model.options and conf.model.options.azure_api_version) - + if auth_param_name and auth_param_value and auth_param_location == "query" then query_table[auth_param_name] = auth_param_value end diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua index 372a57fa827..a32ad5120e6 100644 --- a/kong/llm/drivers/bedrock.lua +++ b/kong/llm/drivers/bedrock.lua @@ -266,7 +266,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 1aafc9405b0..28c8c64eaac 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -21,7 +21,7 @@ local _CHAT_ROLES = { local function handle_stream_event(event_t, model_info, route_type) local metadata - + -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then return @@ -31,12 +31,12 @@ local function handle_stream_event(event_t, model_info, route_type) if err then return nil, "failed to decode event frame from cohere: " .. 
err, nil end - + local new_event - + if event.event_type == "stream-start" then kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id - + -- ignore the rest of this one new_event = { choices = { @@ -52,7 +52,7 @@ local function handle_stream_event(event_t, model_info, route_type) model = model_info.name, object = "chat.completion.chunk", } - + elseif event.event_type == "text-generation" then -- this is a token if route_type == "stream/llm/v1/chat" then @@ -137,19 +137,19 @@ end local function handle_json_inference_event(request_table, model) request_table.temperature = request_table.temperature request_table.max_tokens = request_table.max_tokens - + request_table.p = request_table.top_p request_table.k = request_table.top_k - + request_table.top_p = nil request_table.top_k = nil - + request_table.model = model.name or request_table.model request_table.stream = request_table.stream or false -- explicitly set this - + if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") - + elseif request_table.messages then -- we have to move all BUT THE LAST message into "chat_history" array -- and move the LAST message (from 'user') into "message" string @@ -164,26 +164,26 @@ local function handle_json_inference_event(request_table, model) else role = _CHAT_ROLES.user end - + chat_history[i] = { role = role, message = v.content, } end end - + request_table.chat_history = chat_history end - + request_table.message = request_table.messages[#request_table.messages].content request_table.messages = nil - + elseif request_table.prompt then request_table.prompt = request_table.prompt request_table.messages = nil request_table.message = nil end - + return request_table, "application/json", nil end @@ -202,7 +202,7 @@ local transformers_from = { -- messages/choices table is only 1 size, so don't need to static allocate local messages = {} messages.choices = {} - + if response_table.prompt and response_table.generations then -- this is a "co.generate" for i, v in ipairs(response_table.generations) do @@ -215,7 +215,7 @@ local transformers_from = { messages.object = "text_completion" messages.model = model_info.name messages.id = response_table.id - + local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units @@ -230,10 +230,10 @@ local transformers_from = { and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats - + elseif response_table.text then -- this is a "co.chat" - + messages.choices[1] = { index = 0, message = { @@ -245,7 +245,7 @@ local transformers_from = { messages.object = "chat.completion" messages.model = model_info.name messages.id = response_table.generation_id - + local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units @@ -260,10 +260,10 @@ local transformers_from = { and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats - + else -- probably a fault return nil, "'text' or 'generations' missing from cohere response body" - + end return cjson.encode(messages) @@ -314,17 +314,17 @@ local transformers_from = { prompt.object = "chat.completion" prompt.model = model_info.name prompt.id = response_table.generation_id - + local stats = { completion_tokens = response_table.token_count and response_table.token_count.response_tokens, prompt_tokens = 
response_table.token_count and response_table.token_count.prompt_tokens, total_tokens = response_table.token_count and response_table.token_count.total_tokens, } prompt.usage = stats - + else -- probably a fault return nil, "'text' or 'generations' missing from cohere response body" - + end return cjson.encode(prompt) @@ -465,7 +465,7 @@ function _M.configure_request(conf) and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path or "/" end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 57ca7127ef2..0a68a0af8e1 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -123,7 +123,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) } end end - + new_r.generationConfig = to_gemini_generation_config(request_table) return new_r, "application/json", nil @@ -222,7 +222,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 02b65818bd3..0526453f8a5 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -113,7 +113,7 @@ local function to_raw(request_table, model) messages.parameters.top_k = request_table.top_k messages.parameters.temperature = request_table.temperature messages.parameters.stream = request_table.stream or false -- explicitly set this - + if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run raw 'prompt' and chat history 'messages' requests at the same time - refer to schema") diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index ca1e37d9183..331ad3b5e7e 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -72,7 +72,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", @@ -203,7 +203,7 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) parsed_url.path = path end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 8004b9384df..15d9ce7e62f 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -259,7 +259,7 @@ function _M.frame_to_events(frame, provider) -- it may start with ',' which is the start of the new frame frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame - + -- it may end with the array terminator ']' indicating the finished stream if string.sub(str_rtrim(frame), -1) == "]" then frame = string.sub(str_rtrim(frame), 1, -2) @@ -446,7 +446,7 @@ function _M.from_ollama(response_string, model_info, 
route_type) end end - + if output and output ~= _M._CONST.SSE_TERMINATOR then output, err = cjson.encode(output) end diff --git a/spec/02-integration/22-ai_plugins/01-reports_spec.lua b/spec/02-integration/22-ai_plugins/01-reports_spec.lua index 78c98c03153..9c4858e7127 100644 --- a/spec/02-integration/22-ai_plugins/01-reports_spec.lua +++ b/spec/02-integration/22-ai_plugins/01-reports_spec.lua @@ -38,32 +38,32 @@ for _, strategy in helpers.each_strategy() do local fixtures = { http_mock = {}, } - + fixtures.http_mock.openai = [[ server { server_name openai; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - - + + location = "/llm/v1/chat/good" { content_by_lua_block { local pl_file = require "pl.file" local json = require("cjson.safe") - + ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) diff --git a/spec/03-plugins/38-ai-proxy/00-config_spec.lua b/spec/03-plugins/38-ai-proxy/00-config_spec.lua index bbd495918bb..0a15f131b46 100644 --- a/spec/03-plugins/38-ai-proxy/00-config_spec.lua +++ b/spec/03-plugins/38-ai-proxy/00-config_spec.lua @@ -84,7 +84,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() end local ok, err = validate(config) - + assert.is_truthy(ok) assert.is_falsy(err) end) @@ -220,7 +220,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.equal(err["config"]["@entity"][1], "must set one of 'auth.header_name', 'auth.param_name', " .. "and its respective options, when provider is not self-hosted") assert.is_falsy(ok) @@ -244,7 +244,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.equals(err["config"]["@entity"][1], "all or none of these fields must be set: 'auth.header_name', 'auth.header_value'") assert.is_falsy(ok) end) @@ -268,7 +268,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) @@ -317,7 +317,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) @@ -344,7 +344,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 009f079195d..a73f12a409b 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -532,7 +532,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() local expected_request_json = pl_file.read(filename) local expected_request_table, err = cjson.decode(expected_request_json) assert.is_nil(err) - + -- compare the tables assert.same(expected_request_table, actual_request_table) end) @@ -547,7 +547,7 @@ describe(PLUGIN_NAME .. 
": (unit)", function() local filename if l.config.provider == "llama2" then filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.llama2_format, pl_replace(k, "/", "-")) - + elseif l.config.provider == "mistral" then filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.mistral_format, pl_replace(k, "/", "-")) @@ -604,7 +604,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("converts to provider request format correctly", function() -- load the real provider frame from file local real_stream_frame = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/real-stream-frames/%s/%s.txt", config.provider, pl_replace(format_name, "/", "-"))) - + -- use the shared function to produce an SSE format object local real_transformed_frame, err = ai_shared.frame_to_events(real_stream_frame) assert.is_nil(err) @@ -628,7 +628,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- generic tests it("throws correct error when format is not supported", function() local driver = require("kong.llm.drivers.mistral") -- one-shot, random example of provider with only prompt support - + local model_config = { route_type = "llm/v1/chatnopenotsupported", name = "mistral-tiny", @@ -651,7 +651,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.equal(err, "no transformer available to format mistral://llm/v1/chatnopenotsupported/ollama") end) - + it("produces a correct default config merge", function() local formatted, err = ai_shared.merge_config_defaults( SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS, @@ -675,7 +675,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) describe("streaming transformer tests", function() - + it("transforms truncated-json type (beginning of stream)", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin")) local events = ai_shared.frame_to_events(input, "gemini") @@ -695,7 +695,7 @@ describe(PLUGIN_NAME .. 
": (unit)", function() assert.same(events, expected_events, true) end) - + it("transforms complete-json type", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) local events = ai_shared.frame_to_events(input, "cohere") -- not "truncated json mode" like Gemini diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index 0b5468e2e88..e963d908e32 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -68,14 +68,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.openai = [[ server { server_name openai; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -93,7 +93,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -118,7 +118,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -136,7 +136,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -145,7 +145,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -166,7 +166,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local token_query = ngx.req.get_uri_args()["apikey"] if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -184,7 +184,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) } @@ -664,7 +664,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -692,7 +692,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- 
validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -739,7 +739,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -817,7 +817,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -830,7 +830,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -849,7 +849,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -937,7 +937,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -956,7 +956,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -979,7 +979,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -1038,7 +1038,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe("one-shot request", function() it("success", function() local ai_driver = require("kong.llm.drivers.openai") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -1054,7 +1054,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - + local request = { messages = { [1] = { @@ -1067,15 +1067,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } } - + -- convert it to the specified driver format local ai_request = ai_driver.to_format(request, plugin_conf.model, "llm/v1/chat") - + -- send it to the ai service local ai_response, status_code, err = ai_driver.subrequest(ai_request, plugin_conf, {}, false) assert.is_nil(err) assert.equal(200, status_code) - + -- parse and convert the response local ai_response, _, err = ai_driver.from_format(ai_response, plugin_conf.model, plugin_conf.route_type) assert.is_nil(err) @@ -1092,7 +1092,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("404", function() local ai_driver = require("kong.llm.drivers.openai") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -1108,7 +1108,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, 
}, } - + local request = { messages = { [1] = { @@ -1121,7 +1121,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } } - + -- convert it to the specified driver format local ai_request = ai_driver.to_format(request, plugin_conf.model, "llm/v1/chat") diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index c3cdc525c61..78f990fe616 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -17,14 +17,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.anthropic = [[ server { server_name anthropic; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -36,7 +36,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.messages) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) @@ -61,7 +61,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.messages) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) @@ -129,7 +129,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) } @@ -138,7 +138,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/internal_server_error.html")) @@ -156,7 +156,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.prompt) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) @@ -174,7 +174,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) } @@ -501,7 +501,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -514,7 +514,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -558,7 +558,7 @@ for 
_, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -573,7 +573,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -619,7 +619,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -640,7 +640,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua index 721cf97566e..548db5e59be 100644 --- a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua @@ -16,14 +16,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.cohere = [[ server { server_name cohere; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -78,7 +78,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/responses/bad_request.json")) } @@ -87,7 +87,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/responses/internal_server_error.html")) @@ -105,7 +105,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.prompt) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/responses/bad_request.json")) @@ -123,7 +123,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/responses/bad_request.json")) } @@ -356,7 +356,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -378,7 +378,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - 
+ local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -391,7 +391,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -409,7 +409,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -433,7 +433,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -448,7 +448,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -466,7 +466,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -487,7 +487,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua index a8efe9b21a1..d76d0c4ac50 100644 --- a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua @@ -16,14 +16,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.azure = [[ server { server_name azure; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -35,7 +35,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -60,7 +60,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -78,7 +78,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -87,7 +87,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = 
require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -105,7 +105,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -123,7 +123,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) } @@ -370,7 +370,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -392,7 +392,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -405,7 +405,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -424,7 +424,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -450,7 +450,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -465,7 +465,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -484,7 +484,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -507,7 +507,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua index 3c711cd83b4..94058750ff1 100644 --- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua @@ -16,12 +16,12 @@ for _, strategy in helpers.all_strategies() do if 
strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.mistral = [[ server { server_name mistral; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location = "/v1/chat/completions" { @@ -34,7 +34,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -59,7 +59,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.prompt == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -307,7 +307,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -329,7 +329,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -357,7 +357,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua index aa74ef9fd5b..778804d4af6 100644 --- a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua @@ -16,12 +16,12 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.llama2 = [[ server { server_name llama2; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location = "/raw/llm/v1/chat" { @@ -155,7 +155,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -177,7 +177,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-chat.json"), }) - + local body = assert.res_status(200, r) local json = cjson.decode(body) @@ -192,7 +192,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), }) - + local body = assert.res_status(200, r) local json = cjson.decode(body) @@ -203,7 +203,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe("one-shot request", function() it("success", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -220,7 +220,7 @@ for _, strategy in 
helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - + local request = { messages = { [1] = { @@ -260,7 +260,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("404", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -303,7 +303,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("401", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { diff --git a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua index 0cc63ba41ba..0e9801a2923 100644 --- a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua @@ -152,7 +152,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() @@ -235,7 +235,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index 0d78e57b778..24707b8039e 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -494,7 +494,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -691,7 +691,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then end end until not buffer - + assert.equal(#events, 17) assert.equal(buf:tostring(), "1 + 1 = 2. 
This is the most basic example of addition.")
       end)
@@ -753,7 +753,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
             end
           end
         until not buffer
-        
+
         assert.equal(#events, 8)
         assert.equal(buf:tostring(), "1 + 1 = 2")
       end)
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 0b051da1479..9598bab7f56 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -158,7 +158,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
       server {
           server_name llm;
           listen ]]..MOCK_PORT..[[;
-          
+
           default_type 'application/json';
 
           location ~/flat {
@@ -171,7 +171,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           location = "/badrequest" {
             content_by_lua_block {
               local pl_file = require "pl.file"
-              
+
               ngx.status = 400
               ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json"))
             }
@@ -180,7 +180,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           location = "/internalservererror" {
             content_by_lua_block {
               local pl_file = require "pl.file"
-              
+
               ngx.status = 500
               ngx.header["content-type"] = "text/html"
               ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html"))
@@ -248,7 +248,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
       }, nil, nil, fixtures))
     end)
-    
+
     lazy_teardown(function()
       helpers.stop_kong()
     end)
@@ -270,7 +270,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
         body = REQUEST_BODY,
       })
-      
+
       local body = assert.res_status(200 , r)
       local body_table, err = cjson.decode(body)
 
diff --git a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua
index eab961081e6..b2c11519b05 100644
--- a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua
+++ b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua
@@ -18,7 +18,7 @@ local function create_request(typ)
   if typ ~= "chat" and typ ~= "completions" then
     error("type must be one of 'chat' or 'completions'", 2)
   end
-  
+
   return setmetatable({
     messages = messages,
     type = typ,

From a69eebff9ee61046c5f5d41d11813f3c9f390757 Mon Sep 17 00:00:00 2001
From: Wangchong Zhou
Date: Tue, 4 Jun 2024 17:04:01 +0800
Subject: [PATCH 5/5] docs(changelog): add changelog for #13155

---
 changelog/unreleased/kong/fix-ai-gzip-content.yml | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 changelog/unreleased/kong/fix-ai-gzip-content.yml

diff --git a/changelog/unreleased/kong/fix-ai-gzip-content.yml b/changelog/unreleased/kong/fix-ai-gzip-content.yml
new file mode 100644
index 00000000000..ebbad1f1747
--- /dev/null
+++ b/changelog/unreleased/kong/fix-ai-gzip-content.yml
@@ -0,0 +1,4 @@
+message: |
+  **AI-Proxy**: Fixed an issue where the response was gzipped even when the client did not accept gzip encoding.
+type: bugfix
+scope: Plugin
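
Note (illustrative only, not part of the patch series): the changelog above describes the gzip behavior being fixed. The following is a minimal Lua sketch of that general idea, assuming a response-phase helper; the function name restore_plain_body and its placement are hypothetical, while kong.tools.gzip and the PDK calls are taken from the surrounding code.

-- Sketch: only keep the upstream's gzip Content-Encoding when the client
-- advertised support for it; otherwise inflate the body before returning it.
local kong_utils = require("kong.tools.gzip")

local function restore_plain_body(response_body)
  local upstream_gzipped = kong.service.response.get_header("Content-Encoding") == "gzip"
  local accept_encoding  = kong.request.get_header("Accept-Encoding") or ""
  local client_accepts   = accept_encoding:find("gzip", 1, true) ~= nil

  if upstream_gzipped and not client_accepts then
    -- client cannot handle gzip: strip the header and hand back the inflated body
    kong.response.clear_header("Content-Encoding")
    return kong_utils.inflate_gzip(response_body)
  end

  return response_body
end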