From ec9b3566b0caa714dbc6edaa90a4a5701763fdc8 Mon Sep 17 00:00:00 2001
From: Jack Tysoe <91137069+tysoekong@users.noreply.github.com>
Date: Mon, 6 May 2024 12:14:01 +0100
Subject: [PATCH] fix(ai proxy): 3.7 regression fixes rollup (#12974)

* fix(ai-proxy): remove unsupported top_k parameter from openai format(s)

* fix(ai-proxy): broken ollama-type streaming events

* fix(ai-proxy): streaming setting moved ai-proxy only

* fix(ai-proxy): anthropic token counts

* fix(ai-proxy): wrong analytics format; missing azure extra analytics

* fix(ai-proxy): correct plugin name in log serializer

* fix(ai-proxy): store body string in case of regression

* fix(ai-proxy): fix tests

(cherry picked from commit 84700562632045223a05a077be74771bbd030f06)
---
 kong/clustering/compat/checkers.lua           |  4 +-
 kong/llm/drivers/anthropic.lua                |  7 +--
 kong/llm/drivers/azure.lua                    |  8 ++-
 kong/llm/drivers/openai.lua                   |  2 +
 kong/llm/drivers/shared.lua                   | 58 +++++++++++++------
 kong/llm/init.lua                             |  6 --
 kong/plugins/ai-proxy/handler.lua             | 23 +++++---
 kong/plugins/ai-proxy/schema.lua              | 20 ++++++-
 .../09-hybrid_mode/09-config-compat_spec.lua  |  4 +-
 .../02-openai_integration_spec.lua            |  5 +-
 .../02-integration_spec.lua                   |  1 -
 .../02-integration_spec.lua                   |  1 -
 12 files changed, 92 insertions(+), 47 deletions(-)

diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua
index 6c361dc853e4..b33fcd5726d9 100644
--- a/kong/clustering/compat/checkers.lua
+++ b/kong/clustering/compat/checkers.lua
@@ -31,8 +31,8 @@ local compatible_checkers = {
         if plugin.name == 'ai-proxy' then
           local config = plugin.config
           if config.model and config.model.options then
-            if config.model.options.response_streaming then
-              config.model.options.response_streaming = nil
+            if config.response_streaming then
+              config.response_streaming = nil
               log_warn_message('configures ' .. plugin.name .. ' plugin with' ..
                                ' response_streaming == nil, because it is not supported' ..
                                ' in this release',
diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua
index 8eb206b8c1f6..a18774b331d0 100644
--- a/kong/llm/drivers/anthropic.lua
+++ b/kong/llm/drivers/anthropic.lua
@@ -209,8 +209,7 @@ local function handle_stream_event(event_t, model_info, route_type)
     and event_data.usage then
       return nil, nil, {
         prompt_tokens = nil,
-        completion_tokens = event_data.meta.usage
-                              and event_data.meta.usage.output_tokens
+        completion_tokens = event_data.usage.output_tokens
                               or nil,
         stop_reason = event_data.delta
                         and event_data.delta.stop_reason
@@ -336,7 +335,7 @@ function _M.from_format(response_string, model_info, route_type)
     return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
   end
 
-  local ok, response_string, err = pcall(transform, response_string, model_info, route_type)
+  local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type)
   if not ok or err then
     return nil, fmt("transformation failed from type %s://%s: %s",
                     model_info.provider,
@@ -345,7 +344,7 @@ function _M.from_format(response_string, model_info, route_type)
     )
   end
 
-  return response_string, nil
+  return response_string, nil, metadata
 end
 
 function _M.to_format(request_table, model_info, route_type)
diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua
index 390a96256cb2..a0ba1741a861 100644
--- a/kong/llm/drivers/azure.lua
+++ b/kong/llm/drivers/azure.lua
@@ -22,9 +22,11 @@ function _M.pre_request(conf)
 
   -- for azure provider, all of these must/will be set by now
   if conf.logging and conf.logging.log_statistics then
-    kong.log.set_serialize_value("ai.meta.azure_instance_id", conf.model.options.azure_instance)
-    kong.log.set_serialize_value("ai.meta.azure_deployment_id", conf.model.options.azure_deployment_id)
-    kong.log.set_serialize_value("ai.meta.azure_api_version", conf.model.options.azure_api_version)
+    kong.ctx.plugin.ai_extra_meta = {
+      ["azure_instance_id"] = conf.model.options.azure_instance,
+      ["azure_deployment_id"] = conf.model.options.azure_deployment_id,
+      ["azure_api_version"] = conf.model.options.azure_api_version,
+    }
   end
 
   return true
diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua
index 9f4965ece0d9..b08f29bc3255 100644
--- a/kong/llm/drivers/openai.lua
+++ b/kong/llm/drivers/openai.lua
@@ -20,6 +20,7 @@ local transformers_to = {
   ["llm/v1/chat"] = function(request_table, model_info, route_type)
     request_table.model = request_table.model or model_info.name
     request_table.stream = request_table.stream or false  -- explicitly set this
+    request_table.top_k = nil  -- explicitly remove unsupported default
 
     return request_table, "application/json", nil
   end,
@@ -27,6 +28,7 @@ local transformers_to = {
   ["llm/v1/completions"] = function(request_table, model_info, route_type)
     request_table.model = model_info.name
     request_table.stream = request_table.stream or false  -- explicitly set this
+    request_table.top_k = nil  -- explicitly remove unsupported default
 
     return request_table, "application/json", nil
   end,
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index f4696f116c10..0b9cdcf3ab35 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -24,9 +24,9 @@ local log_entry_keys = {
   TOKENS_CONTAINER = "usage",
   META_CONTAINER = "meta",
   PAYLOAD_CONTAINER = "payload",
-  REQUEST_BODY = "ai.payload.request",
 
   -- payload keys
+  REQUEST_BODY = "request",
   RESPONSE_BODY = "response",
 
   -- meta keys
@@ -264,20 +264,30 @@ function _M.to_ollama(request_table, model)
 end
 
 function _M.from_ollama(response_string, model_info, route_type)
-  local output, _, analytics
-
-  local response_table, err = cjson.decode(response_string)
-  if err then
-    return nil, "failed to decode ollama response"
-  end
+  local output, err, _, analytics
 
   if route_type == "stream/llm/v1/chat" then
+    local response_table, err = cjson.decode(response_string.data)
+    if err then
+      return nil, "failed to decode ollama response"
+    end
+
     output, _, analytics = handle_stream_event(response_table, model_info, route_type)
 
   elseif route_type == "stream/llm/v1/completions" then
+    local response_table, err = cjson.decode(response_string.data)
+    if err then
+      return nil, "failed to decode ollama response"
+    end
+
     output, _, analytics = handle_stream_event(response_table, model_info, route_type)
 
   else
+    local response_table, err = cjson.decode(response_string)
+    if err then
+      return nil, "failed to decode ollama response"
+    end
+
     -- there is no direct field indicating STOP reason, so calculate it manually
     local stop_length = (model_info.options and model_info.options.max_tokens) or -1
     local stop_reason = "stop"
@@ -405,14 +415,14 @@ function _M.pre_request(conf, request_table)
     request_table[auth_param_name] = auth_param_value
   end
 
-  if conf.logging and conf.logging.log_statistics then
-    kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
-    kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
-  end
-
   -- if enabled AND request type is compatible, capture the input for analytics
   if conf.logging and conf.logging.log_payloads then
-    kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
+    local plugin_name = conf.__key__:match('plugins:(.-):')
+    if not plugin_name or plugin_name == "" then
+      return nil, "no plugin name is being passed by the plugin"
+    end
+
+    kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body())
   end
 
   -- log tokens prompt for reports and billing
@@ -468,7 +478,6 @@ function _M.post_request(conf, response_object)
   if not request_analytics_plugin then
     request_analytics_plugin = {
       [log_entry_keys.META_CONTAINER] = {},
-      [log_entry_keys.PAYLOAD_CONTAINER] = {},
       [log_entry_keys.TOKENS_CONTAINER] = {
         [log_entry_keys.PROMPT_TOKEN] = 0,
         [log_entry_keys.COMPLETION_TOKEN] = 0,
@@ -478,11 +487,18 @@
   end
 
   -- Set the model, response, and provider names in the current try context
-  request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
+  request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
   request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
   request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
   request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
 
+  -- set extra per-provider meta
+  if kong.ctx.plugin.ai_extra_meta and type(kong.ctx.plugin.ai_extra_meta) == "table" then
+    for k, v in pairs(kong.ctx.plugin.ai_extra_meta) do
+      request_analytics_plugin[log_entry_keys.META_CONTAINER][k] = v
+    end
+  end
+
   -- Capture openai-format usage stats from the transformed response body
   if response_object.usage then
     if response_object.usage.prompt_tokens then
@@ -498,16 +514,24 @@ function _M.post_request(conf, response_object)
 
   -- Log response body if logging payloads is enabled
   if conf.logging and conf.logging.log_payloads then
-    request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string
+    kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.RESPONSE_BODY), body_string)
   end
 
   -- Update context with changed values
+  request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER] = {
+    [log_entry_keys.RESPONSE_BODY] = body_string,
+  }
   request_analytics[plugin_name] = request_analytics_plugin
   kong.ctx.shared.analytics = request_analytics
 
   if conf.logging and conf.logging.log_statistics then
     -- Log analytics data
-    kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin)
+    kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.TOKENS_CONTAINER),
+      request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER])
+
+    -- Log meta
+    kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.META_CONTAINER),
+      request_analytics_plugin[log_entry_keys.META_CONTAINER])
   end
 
   -- log tokens response for reports and billing
diff --git a/kong/llm/init.lua b/kong/llm/init.lua
index 6ac1a1ff0b96..af3833ff44f1 100644
--- a/kong/llm/init.lua
+++ b/kong/llm/init.lua
@@ -47,12 +47,6 @@ local model_options_schema = {
   type = "record",
   required = false,
   fields = {
-    { response_streaming = {
-        type = "string",
-        description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
-        required = false,
-        default = "allow",
-        one_of = { "allow", "deny", "always" } }},
     { max_tokens = {
         type = "integer",
         description = "Defines the max_tokens, if using chat or completion models.",
diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index cb2baabc4d70..739c33f0667d 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -112,7 +112,7 @@ local function handle_streaming_frame(conf)
       if not event_t then
         event_t, err = cjson.decode(formatted)
       end
-      
+
       if not err then
         if not token_t then
           token_t = get_token_text(event_t)
@@ -126,10 +126,6 @@ local function handle_streaming_frame(conf)
             (kong.ctx.plugin.ai_stream_completion_tokens or 0) +
             math.ceil(#strip(token_t) / 4)
         end
       end
-
-    elseif metadata then
-      kong.ctx.plugin.ai_stream_completion_tokens = metadata.completion_tokens or kong.ctx.plugin.ai_stream_completion_tokens
-      kong.ctx.plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or kong.ctx.plugin.ai_stream_prompt_tokens
     end
   end
@@ -137,6 +133,17 @@ local function handle_streaming_frame(conf)
 
       framebuffer:put(formatted or "")
       framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "")
     end
+
+    if conf.logging and conf.logging.log_statistics and metadata then
+      kong.ctx.plugin.ai_stream_completion_tokens =
+        (kong.ctx.plugin.ai_stream_completion_tokens or 0) +
+        (metadata.completion_tokens or 0)
+        or kong.ctx.plugin.ai_stream_completion_tokens
+      kong.ctx.plugin.ai_stream_prompt_tokens =
+        (kong.ctx.plugin.ai_stream_prompt_tokens or 0) +
+        (metadata.prompt_tokens or 0)
+        or kong.ctx.plugin.ai_stream_prompt_tokens
+    end
   end
 end
@@ -367,10 +374,12 @@ function _M:access(conf)
 
   -- check if the user has asked for a stream, and/or if
   -- we are forcing all requests to be of streaming type
   if request_table and request_table.stream or
-     (conf_m.model.options and conf_m.model.options.response_streaming) == "always" then
+     (conf_m.response_streaming and conf_m.response_streaming == "always") then
+    request_table.stream = true
+
     -- this condition will only check if user has tried
     -- to activate streaming mode within their request
-    if conf_m.model.options and conf_m.model.options.response_streaming == "deny" then
+    if conf_m.response_streaming and conf_m.response_streaming == "deny" then
       return bad_request("response streaming is not enabled for this LLM")
     end
diff --git a/kong/plugins/ai-proxy/schema.lua b/kong/plugins/ai-proxy/schema.lua
index 9259582c9ac2..52bafe129c3e 100644
--- a/kong/plugins/ai-proxy/schema.lua
+++ b/kong/plugins/ai-proxy/schema.lua
@@ -1,5 +1,23 @@
 local typedefs = require("kong.db.schema.typedefs")
 local llm = require("kong.llm")
+local deep_copy = require("kong.tools.utils").deep_copy
+
+local this_schema = deep_copy(llm.config_schema)
+
+local ai_proxy_only_config = {
+  {
+    response_streaming = {
+      type = "string",
+      description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
+      required = false,
+      default = "allow",
+      one_of = { "allow", "deny", "always" }},
+  },
+}
+
+for i, v in pairs(ai_proxy_only_config) do
+  this_schema.fields[#this_schema.fields+1] = v
+end
 
 return {
   name = "ai-proxy",
@@ -7,6 +25,6 @@ return {
     { protocols = typedefs.protocols_http },
     { consumer = typedefs.no_consumer },
     { service = typedefs.no_service },
-    { config = llm.config_schema },
+    { config = this_schema },
   },
 }
diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
index 96f41bfd03d1..4cfa96efea6c 100644
--- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
+++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
@@ -480,6 +480,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
         name = "ai-proxy",
         enabled = true,
         config = {
+          response_streaming = "allow", -- becomes nil
           route_type = "preserve", -- becomes 'llm/v1/chat'
           auth = {
             header_name = "header",
@@ -491,7 +492,6 @@ describe("CP/DP config compat transformations #" .. strategy, function()
           options = {
             max_tokens = 512,
             temperature = 0.5,
-            response_streaming = "allow", -- becomes nil
             upstream_path = "/anywhere", -- becomes nil
           },
         },
@@ -500,7 +500,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
       -- ]]
 
       local expected_ai_proxy_prior_37 = utils.cycle_aware_deep_copy(ai_proxy)
-      expected_ai_proxy_prior_37.config.model.options.response_streaming = nil
+      expected_ai_proxy_prior_37.config.response_streaming = nil
       expected_ai_proxy_prior_37.config.model.options.upstream_path = nil
       expected_ai_proxy_prior_37.config.route_type = "llm/v1/chat"
 
diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
index b7a55183dca3..e9fb74c3114a 100644
--- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -42,7 +42,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-3.5-turbo',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
@@ -775,8 +774,8 @@ for _, strategy in helpers.all_strategies() do
         if strategy ~= "cassandra" then
           assert.is_number(log_message.response.size)
           -- test request bodies
-          assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true)
-          assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)
+          assert.matches('"content": "What is 1 + 1?"', log_message.ai['ai-proxy'].payload.request, nil, true)
+          assert.matches('"role": "user"', log_message.ai['ai-proxy'].payload.request, nil, true)
 
           -- test response bodies
           assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true)
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 2711f4aa393f..00b0391d7499 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -122,7 +122,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-4',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
index 13e4b558a3ef..800100c9a67c 100644
--- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
@@ -179,7 +179,6 @@ local _EXPECTED_CHAT_STATS = {
       request_model = 'gpt-4',
       response_model = 'gpt-3.5-turbo-0613',
     },
-    payload = {},
     usage = {
       completion_token = 12,
       prompt_token = 25,
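
Note for reviewers (not part of the patch): the sketch below approximates the
per-plugin log entry shape this rollup converges on, reconstructed from the
log_entry_keys changes in kong/llm/drivers/shared.lua and the
log_message.ai['ai-proxy'] assertions in the specs above. Sample values come
from the _EXPECTED_CHAT_STATS fixtures; anything marked "assumed" is
illustrative and not taken verbatim from the diff.

-- Illustrative only: approximate shape of the serialized "ai" log entry,
-- now namespaced by plugin name ("ai-proxy" here) instead of the old flat
-- "ai.payload.*" and top-level keys.
local sample_ai_log_entry = {
  ["ai-proxy"] = {
    meta = {
      request_model  = "gpt-3.5-turbo",
      response_model = "gpt-3.5-turbo-0613",
      provider_name  = "openai",   -- assumed spelling of log_entry_keys.PROVIDER_NAME
      plugin_id      = "<uuid>",   -- assumed spelling of log_entry_keys.PLUGIN_ID
      -- plus any per-provider extras copied from kong.ctx.plugin.ai_extra_meta,
      -- e.g. azure_instance_id / azure_deployment_id / azure_api_version
    },
    usage = {
      prompt_token     = 25,
      completion_token = 12,
    },
    payload = {                    -- serialized only when logging.log_payloads is enabled
      request  = '{"messages":[{"role":"user","content":"What is 1 + 1?"}]}',
      response = '<raw response body string>',
    },
  },
}

print(sample_ai_log_entry["ai-proxy"].usage.prompt_token)

Separately, note that response_streaming now lives at config.response_streaming
(ai-proxy only) rather than config.model.options.response_streaming, which is
what the schema, compat-checker, and hybrid-mode spec changes above exercise.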