diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua index a14f41f09a22..bc77bdab0ea7 100644 --- a/kong/clustering/compat/checkers.lua +++ b/kong/clustering/compat/checkers.lua @@ -127,8 +127,8 @@ local compatible_checkers = { if plugin.name == 'ai-proxy' then local config = plugin.config if config.model and config.model.options then - if config.model.options.response_streaming then - config.model.options.response_streaming = nil + if config.response_streaming then + config.response_streaming = nil log_warn_message('configures ' .. plugin.name .. ' plugin with' .. ' response_streaming == nil, because it is not supported' .. ' in this release', diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 8f633e185662..0b2c536160e0 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -216,8 +216,7 @@ local function handle_stream_event(event_t, model_info, route_type) and event_data.usage then return nil, nil, { prompt_tokens = nil, - completion_tokens = event_data.meta.usage - and event_data.meta.usage.output_tokens + completion_tokens = event_data.usage.output_tokens or nil, stop_reason = event_data.delta and event_data.delta.stop_reason @@ -343,7 +342,7 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - local ok, response_string, err = pcall(transform, response_string, model_info, route_type) + local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, @@ -352,7 +351,7 @@ function _M.from_format(response_string, model_info, route_type) ) end - return response_string, nil + return response_string, nil, metadata end function _M.to_format(request_table, model_info, route_type) diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index e49af619196a..9a06f9266486 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -29,9 +29,11 @@ function _M.pre_request(conf) -- for azure provider, all of these must/will be set by now if conf.logging and conf.logging.log_statistics then - kong.log.set_serialize_value("ai.meta.azure_instance_id", conf.model.options.azure_instance) - kong.log.set_serialize_value("ai.meta.azure_deployment_id", conf.model.options.azure_deployment_id) - kong.log.set_serialize_value("ai.meta.azure_api_version", conf.model.options.azure_api_version) + kong.ctx.plugin.ai_extra_meta = { + ["azure_instance_id"] = conf.model.options.azure_instance, + ["azure_deployment_id"] = conf.model.options.azure_deployment_id, + ["azure_api_version"] = conf.model.options.azure_api_version, + } end return true diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 6288c111ce18..78713116a969 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -27,6 +27,7 @@ local transformers_to = { ["llm/v1/chat"] = function(request_table, model_info, route_type) request_table.model = request_table.model or model_info.name request_table.stream = request_table.stream or false -- explicitly set this + request_table.top_k = nil -- explicitly remove unsupported default return request_table, "application/json", nil end, @@ -34,6 +35,7 @@ local transformers_to = { ["llm/v1/completions"] = function(request_table, model_info, route_type) request_table.model = model_info.name request_table.stream = request_table.stream or false -- explicitly set this + request_table.top_k = nil -- explicitly remove unsupported default return request_table, "application/json", nil end, diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 5db89aafb2b3..78cc4e233032 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -31,9 +31,9 @@ local log_entry_keys = { TOKENS_CONTAINER = "usage", META_CONTAINER = "meta", PAYLOAD_CONTAINER = "payload", - REQUEST_BODY = "ai.payload.request", -- payload keys + REQUEST_BODY = "request", RESPONSE_BODY = "response", -- meta keys @@ -271,20 +271,30 @@ function _M.to_ollama(request_table, model) end function _M.from_ollama(response_string, model_info, route_type) - local output, _, analytics - - local response_table, err = cjson.decode(response_string) - if err then - return nil, "failed to decode ollama response" - end + local output, err, _, analytics if route_type == "stream/llm/v1/chat" then + local response_table, err = cjson.decode(response_string.data) + if err then + return nil, "failed to decode ollama response" + end + output, _, analytics = handle_stream_event(response_table, model_info, route_type) elseif route_type == "stream/llm/v1/completions" then + local response_table, err = cjson.decode(response_string.data) + if err then + return nil, "failed to decode ollama response" + end + output, _, analytics = handle_stream_event(response_table, model_info, route_type) else + local response_table, err = cjson.decode(response_string) + if err then + return nil, "failed to decode ollama response" + end + -- there is no direct field indicating STOP reason, so calculate it manually local stop_length = (model_info.options and model_info.options.max_tokens) or -1 local stop_reason = "stop" @@ -412,14 +422,14 @@ function _M.pre_request(conf, request_table) request_table[auth_param_name] = auth_param_value end - if conf.logging and conf.logging.log_statistics then - kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name) - kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider) - end - -- if enabled AND request type is compatible, capture the input for analytics if conf.logging and conf.logging.log_payloads then - kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body()) + local plugin_name = conf.__key__:match('plugins:(.-):') + if not plugin_name or plugin_name == "" then + return nil, "no plugin name is being passed by the plugin" + end + + kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body()) end -- log tokens prompt for reports and billing @@ -475,7 +485,6 @@ function _M.post_request(conf, response_object) if not request_analytics_plugin then request_analytics_plugin = { [log_entry_keys.META_CONTAINER] = {}, - [log_entry_keys.PAYLOAD_CONTAINER] = {}, [log_entry_keys.TOKENS_CONTAINER] = { [log_entry_keys.PROMPT_TOKEN] = 0, [log_entry_keys.COMPLETION_TOKEN] = 0, @@ -485,11 +494,18 @@ function _M.post_request(conf, response_object) end -- Set the model, response, and provider names in the current try context - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id + -- set extra per-provider meta + if kong.ctx.plugin.ai_extra_meta and type(kong.ctx.plugin.ai_extra_meta) == "table" then + for k, v in pairs(kong.ctx.plugin.ai_extra_meta) do + request_analytics_plugin[log_entry_keys.META_CONTAINER][k] = v + end + end + -- Capture openai-format usage stats from the transformed response body if response_object.usage then if response_object.usage.prompt_tokens then @@ -505,16 +521,24 @@ function _M.post_request(conf, response_object) -- Log response body if logging payloads is enabled if conf.logging and conf.logging.log_payloads then - request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string + kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.RESPONSE_BODY), body_string) end -- Update context with changed values + request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER] = { + [log_entry_keys.RESPONSE_BODY] = body_string, + } request_analytics[plugin_name] = request_analytics_plugin kong.ctx.shared.analytics = request_analytics if conf.logging and conf.logging.log_statistics then -- Log analytics data - kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin) + kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.TOKENS_CONTAINER), + request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER]) + + -- Log meta + kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.META_CONTAINER), + request_analytics_plugin[log_entry_keys.META_CONTAINER]) end -- log tokens response for reports and billing diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 88a70e6fc5ad..e807b6bb5dce 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -93,12 +93,6 @@ local model_options_schema = { type = "record", required = false, fields = { - { response_streaming = { - type = "string", - description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.", - required = false, - default = "allow", - one_of = { "allow", "deny", "always" } }}, { max_tokens = { type = "integer", description = "Defines the max_tokens, if using chat or completion models.", diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 140f8edf12d7..2d1de53de020 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -152,7 +152,7 @@ local function handle_streaming_frame(conf) if not event_t then event_t, err = cjson.decode(formatted) end - + if not err then if not token_t then token_t = get_token_text(event_t) @@ -166,10 +166,6 @@ local function handle_streaming_frame(conf) (kong.ctx.plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4) end end - - elseif metadata then - kong.ctx.plugin.ai_stream_completion_tokens = metadata.completion_tokens or kong.ctx.plugin.ai_stream_completion_tokens - kong.ctx.plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or kong.ctx.plugin.ai_stream_prompt_tokens end end @@ -177,6 +173,17 @@ local function handle_streaming_frame(conf) framebuffer:put(formatted or "") framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "") end + + if conf.logging and conf.logging.log_statistics and metadata then + kong.ctx.plugin.ai_stream_completion_tokens = + (kong.ctx.plugin.ai_stream_completion_tokens or 0) + + (metadata.completion_tokens or 0) + or kong.ctx.plugin.ai_stream_completion_tokens + kong.ctx.plugin.ai_stream_prompt_tokens = + (kong.ctx.plugin.ai_stream_prompt_tokens or 0) + + (metadata.prompt_tokens or 0) + or kong.ctx.plugin.ai_stream_prompt_tokens + end end end @@ -427,10 +434,12 @@ function _M:access(conf) -- check if the user has asked for a stream, and/or if -- we are forcing all requests to be of streaming type if request_table and request_table.stream or - (conf_m.model.options and conf_m.model.options.response_streaming) == "always" then + (conf_m.response_streaming and conf_m.response_streaming == "always") then + request_table.stream = true + -- this condition will only check if user has tried -- to activate streaming mode within their request - if conf_m.model.options and conf_m.model.options.response_streaming == "deny" then + if conf_m.response_streaming and conf_m.response_streaming == "deny" then return bad_request("response streaming is not enabled for this LLM") end diff --git a/kong/plugins/ai-proxy/schema.lua b/kong/plugins/ai-proxy/schema.lua index 0f8353360388..92e40d8008b8 100644 --- a/kong/plugins/ai-proxy/schema.lua +++ b/kong/plugins/ai-proxy/schema.lua @@ -7,6 +7,24 @@ local typedefs = require("kong.db.schema.typedefs") local llm = require("kong.llm") +local deep_copy = require("kong.tools.utils").deep_copy + +local this_schema = deep_copy(llm.config_schema) + +local ai_proxy_only_config = { + { + response_streaming = { + type = "string", + description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.", + required = false, + default = "allow", + one_of = { "allow", "deny", "always" }}, + }, +} + +for i, v in pairs(ai_proxy_only_config) do + this_schema.fields[#this_schema.fields+1] = v +end return { name = "ai-proxy", @@ -14,6 +32,6 @@ return { { protocols = typedefs.protocols_http }, { consumer = typedefs.no_consumer }, { service = typedefs.no_service }, - { config = llm.config_schema }, + { config = this_schema }, }, } diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index 1d877eeeef32..3569b3f66bb0 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -550,6 +550,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() name = "ai-proxy", enabled = true, config = { + response_streaming = "allow", -- becomes nil route_type = "preserve", -- becomes 'llm/v1/chat' auth = { header_name = "header", @@ -561,7 +562,6 @@ describe("CP/DP config compat transformations #" .. strategy, function() options = { max_tokens = 512, temperature = 0.5, - response_streaming = "allow", -- becomes nil upstream_path = "/anywhere", -- becomes nil }, }, @@ -570,7 +570,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() -- ]] local expected_ai_proxy_prior_37 = utils.cycle_aware_deep_copy(ai_proxy) - expected_ai_proxy_prior_37.config.model.options.response_streaming = nil + expected_ai_proxy_prior_37.config.response_streaming = nil expected_ai_proxy_prior_37.config.model.options.upstream_path = nil expected_ai_proxy_prior_37.config.route_type = "llm/v1/chat" diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index 0a6148ca87fd..a59b5997edfc 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -49,7 +49,6 @@ local _EXPECTED_CHAT_STATS = { request_model = 'gpt-3.5-turbo', response_model = 'gpt-3.5-turbo-0613', }, - payload = {}, usage = { completion_token = 12, prompt_token = 25, @@ -782,8 +781,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_number(log_message.response.size) -- test request bodies - assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true) - assert.matches('"role": "user"', log_message.ai.payload.request, nil, true) + assert.matches('"content": "What is 1 + 1?"', log_message.ai['ai-proxy'].payload.request, nil, true) + assert.matches('"role": "user"', log_message.ai['ai-proxy'].payload.request, nil, true) -- test response bodies assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true) diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua index a5a471339a3a..fb16bf005474 100644 --- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua @@ -129,7 +129,6 @@ local _EXPECTED_CHAT_STATS = { request_model = 'gpt-4', response_model = 'gpt-3.5-turbo-0613', }, - payload = {}, usage = { completion_token = 12, prompt_token = 25, diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua index 5a897b7983ce..d35b54557fb1 100644 --- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua +++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua @@ -186,7 +186,6 @@ local _EXPECTED_CHAT_STATS = { request_model = 'gpt-4', response_model = 'gpt-3.5-turbo-0613', }, - payload = {}, usage = { completion_token = 12, prompt_token = 25,