diff --git a/changelog/unreleased/kong/update-ai-proxy-telemetry.yml b/changelog/unreleased/kong/update-ai-proxy-telemetry.yml
index e4ac98afa760..fcb68b218de6 100644
--- a/changelog/unreleased/kong/update-ai-proxy-telemetry.yml
+++ b/changelog/unreleased/kong/update-ai-proxy-telemetry.yml
@@ -1,3 +1,3 @@
-message: Update telemetry collection for AI Plugins to allow multiple instances data to be set for the same request.
+message: Update telemetry collection for AI Plugins to allow multiple plugins' data to be set for the same request.
 type: bugfix
 scope: Core
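-- Not part of the patch: an illustrative sketch of the per-plugin analytics entry
-- that the shared driver below now serializes under "ai.<plugin name>". The shape
-- and example values mirror the _EXPECTED_CHAT_STATS tables in the specs further
-- down; treat the concrete values as samples, not fixed output.
local example_ai_log_entry = {
  ["ai-proxy"] = {
    meta = {
      plugin_id = "6e7c40f6-ce96-48e4-a366-d109c169e444",  -- example UUID taken from the spec
      provider_name = "openai",
      request_model = "gpt-3.5-turbo",
      response_model = "gpt-3.5-turbo-0613",
    },
    payload = {},  -- the response body lands here only when logging.log_payloads is enabled
    usage = {
      prompt_token = 25,
      completion_token = 12,
      total_tokens = 37,
    },
  },
}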
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 69c89d9d5c51..041062a724db 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -9,11 +9,13 @@ local parse_url = require("socket.url").parse
 --
 local log_entry_keys = {
-  REQUEST_BODY = "ai.payload.request",
-  RESPONSE_BODY = "payload.response",
-  TOKENS_CONTAINER = "usage",
   META_CONTAINER = "meta",
+  PAYLOAD_CONTAINER = "payload",
+  REQUEST_BODY = "ai.payload.request",
+
+  -- payload keys
+  RESPONSE_BODY = "response",

   -- meta keys
   REQUEST_MODEL = "request_model",
@@ -35,33 +37,6 @@ _M.streaming_has_token_counts = {
   ["llama2"] = true,
 }

---- Splits a table key into nested tables.
--- Each part of the key separated by dots represents a nested table.
--- @param obj The table to split keys for.
--- @return A nested table structure representing the split keys.
-local function split_table_key(obj)
-  local result = {}
-
-  for key, value in pairs(obj) do
-    local keys = {}
-    for k in key:gmatch("[^.]+") do
-      table.insert(keys, k)
-    end
-
-    local currentTable = result
-    for i, k in ipairs(keys) do
-      if i < #keys then
-        currentTable[k] = currentTable[k] or {}
-        currentTable = currentTable[k]
-      else
-        currentTable[k] = value
-      end
-    end
-  end
-
-  return result
-end
-
 _M.upstream_url_format = {
   openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"),
   anthropic = "https://api.anthropic.com:443",
@@ -302,76 +277,65 @@ function _M.post_request(conf, response_object)

   if conf.logging and conf.logging.log_statistics then
     local provider_name = conf.model.provider

+    local plugin_name = conf.__key__:match('plugins:(.-):')
+    if not plugin_name or plugin_name == "" then
+      return nil, "no plugin name is being passed by the plugin"
+    end
+
     -- check if we already have analytics in this context
     local request_analytics = kong.ctx.shared.analytics

-    -- create a new try context
-    local current_try = {
-      [log_entry_keys.META_CONTAINER] = {},
-      [log_entry_keys.TOKENS_CONTAINER] = {},
-    }
-
     -- create a new structure if not
     if not request_analytics then
       request_analytics = {}
     end

     -- check if we already have analytics for this provider
-    local request_analytics_provider = request_analytics[provider_name]
+    local request_analytics_plugin = request_analytics[plugin_name]

     -- create a new structure if not
-    if not request_analytics_provider then
-      request_analytics_provider = {
-        request_prompt_tokens = 0,
-        request_completion_tokens = 0,
-        request_total_tokens = 0,
-        number_of_instances = 0,
-        instances = {},
+    if not request_analytics_plugin then
+      request_analytics_plugin = {
+        [log_entry_keys.META_CONTAINER] = {},
+        [log_entry_keys.PAYLOAD_CONTAINER] = {},
+        [log_entry_keys.TOKENS_CONTAINER] = {
+          [log_entry_keys.PROMPT_TOKEN] = 0,
+          [log_entry_keys.COMPLETION_TOKEN] = 0,
+          [log_entry_keys.TOTAL_TOKENS] = 0,
+        },
       }
     end

     -- Set the model, response, and provider names in the current try context
-    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
-    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
-    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
-    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
+    request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
+    request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
+    request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
+    request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id

     -- Capture openai-format usage stats from the transformed response body
     if response_object.usage then
       if response_object.usage.prompt_tokens then
-        request_analytics_provider.request_prompt_tokens = (request_analytics_provider.request_prompt_tokens + response_object.usage.prompt_tokens)
-        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = response_object.usage.prompt_tokens
+        request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] + response_object.usage.prompt_tokens
       end

       if response_object.usage.completion_tokens then
-        request_analytics_provider.request_completion_tokens = (request_analytics_provider.request_completion_tokens + response_object.usage.completion_tokens)
-        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = response_object.usage.completion_tokens
+        request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] + response_object.usage.completion_tokens
       end

       if response_object.usage.total_tokens then
-        request_analytics_provider.request_total_tokens = (request_analytics_provider.request_total_tokens + response_object.usage.total_tokens)
-        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = response_object.usage.total_tokens
+        request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] + response_object.usage.total_tokens
       end
     end

     -- Log response body if logging payloads is enabled
     if conf.logging and conf.logging.log_payloads then
-      current_try[log_entry_keys.RESPONSE_BODY] = body_string
+      request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string
     end

-    -- Increment the number of instances
-    request_analytics_provider.number_of_instances = request_analytics_provider.number_of_instances + 1
-
-    -- Get the current try count
-    local try_count = request_analytics_provider.number_of_instances
-
-    -- Store the split key data in instances
-    request_analytics_provider.instances[try_count] = split_table_key(current_try)
-
     -- Update context with changed values
-    request_analytics[provider_name] = request_analytics_provider
+    request_analytics[plugin_name] = request_analytics_plugin
     kong.ctx.shared.analytics = request_analytics

     -- Log analytics data
-    kong.log.set_serialize_value(fmt("%s.%s", "ai", provider_name), request_analytics_provider)
+    kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin)
   end

   return nil
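-- Not part of the patch: how post_request() now derives its logging key. The exact
-- layout of conf.__key__ is an assumption here; the pattern only relies on the
-- "plugins:<name>:" prefix, so the value below is a hypothetical cache key.
local example_key = "plugins:ai-proxy:00000000-0000-0000-0000-000000000000"
local plugin_name = example_key:match("plugins:(.-):")
assert(plugin_name == "ai-proxy")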
diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua
index 7553877660a5..9517be366325 100644
--- a/kong/plugins/ai-request-transformer/handler.lua
+++ b/kong/plugins/ai-request-transformer/handler.lua
@@ -45,6 +45,7 @@ function _M:access(conf)
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
   conf.llm.__plugin_id = conf.__plugin_id
+  conf.llm.__key__ = conf.__key__
   local ai_driver, err = llm:new(conf.llm, http_opts)

   if not ai_driver then
diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua
index 7fd4a2900b79..7014d8938526 100644
--- a/kong/plugins/ai-response-transformer/handler.lua
+++ b/kong/plugins/ai-response-transformer/handler.lua
@@ -98,6 +98,7 @@ function _M:access(conf)
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
   conf.llm.__plugin_id = conf.__plugin_id
+  conf.llm.__key__ = conf.__key__
   local ai_driver, err = llm:new(conf.llm, http_opts)

   if not ai_driver then
diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
index b9d8c31888d4..c81d8ab1255c 100644
--- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -35,26 +35,19 @@ local function wait_for_json_log_entry(FILE_LOG_PATH)
 end

 local _EXPECTED_CHAT_STATS = {
-  openai = {
-    instances = {
-      {
-        meta = {
-          plugin_id = '6e7c40f6-ce96-48e4-a366-d109c169e444',
-          provider_name = 'openai',
-          request_model = 'gpt-3.5-turbo',
-          response_model = 'gpt-3.5-turbo-0613',
-        },
-        usage = {
-          completion_token = 12,
-          prompt_token = 25,
-          total_tokens = 37,
-        },
-      },
+  ["ai-proxy"] = {
+    meta = {
+      plugin_id = '6e7c40f6-ce96-48e4-a366-d109c169e444',
+      provider_name = 'openai',
+      request_model = 'gpt-3.5-turbo',
+      response_model = 'gpt-3.5-turbo-0613',
+    },
+    payload = {},
+    usage = {
+      completion_token = 12,
+      prompt_token = 25,
+      total_tokens = 37,
     },
-    number_of_instances = 1,
-    request_completion_tokens = 12,
-    request_prompt_tokens = 25,
-    request_total_tokens = 37,
   },
 }

@@ -691,9 +684,9 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)

         -- test response bodies
-        assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai.openai.instances[1].payload.response, nil, true)
-        assert.matches('"role": "assistant"', log_message.ai.openai.instances[1].payload.response, nil, true)
-        assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', log_message.ai.openai.instances[1].payload.response, nil, true)
+        assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true)
+        assert.matches('"role": "assistant"', log_message.ai["ai-proxy"].payload.response, nil, true)
+        assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', log_message.ai["ai-proxy"].payload.response, nil, true)
       end)

       it("internal_server_error request", function()
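-- Not part of the patch: a sketch of consuming the new structure from a decoded
-- file-log entry (the specs decode it with cjson via wait_for_json_log_entry).
-- With several AI plugins on one request, each contributes its own "ai.<name>"
-- entry; the guard skips non-plugin keys such as ai.payload.
local function total_ai_tokens(log_message)
  local total = 0
  for _, stats in pairs(log_message.ai or {}) do
    if type(stats) == "table" and stats.usage then
      total = total + (stats.usage.total_tokens or 0)
    end
  end
  return total
end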
diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
index 662fb4c9e11a..2711f4aa393f 100644
--- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua
@@ -1,11 +1,41 @@
 local helpers = require "spec.helpers"
 local cjson = require "cjson"
+local pl_file = require "pl.file"
+local pl_stringx = require "pl.stringx"

 local MOCK_PORT = helpers.get_available_port()
 local PLUGIN_NAME = "ai-request-transformer"

+local FILE_LOG_PATH_STATS_ONLY = os.tmpname()
+
+local function wait_for_json_log_entry(FILE_LOG_PATH)
+  local json
+
+  assert
+    .with_timeout(10)
+    .ignore_exceptions(true)
+    .eventually(function()
+      local data = assert(pl_file.read(FILE_LOG_PATH))
+
+      data = pl_stringx.strip(data)
+      assert(#data > 0, "log file is empty")
+
+      data = data:match("%b{}")
+      assert(data, "log file does not contain JSON")
+
+      json = cjson.decode(data)
+    end)
+    .has_no_error("log file contains a valid JSON entry")
+
+  return json
+end
+
 local OPENAI_FLAT_RESPONSE = {
   route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
   model = {
     name = "gpt-4",
     provider = "openai",
@@ -84,6 +114,23 @@ local EXPECTED_RESULT_FLAT = {
   }
 }

+local _EXPECTED_CHAT_STATS = {
+  ["ai-request-transformer"] = {
+    meta = {
+      plugin_id = '71083e79-4921-4f9f-97a4-ee7810b6cd8a',
+      provider_name = 'openai',
+      request_model = 'gpt-4',
+      response_model = 'gpt-3.5-turbo-0613',
+    },
+    payload = {},
+    usage = {
+      completion_token = 12,
+      prompt_token = 25,
+      total_tokens = 37,
+    },
+  },
+}
+
 local SYSTEM_PROMPT = "You are a mathematician. " ..
   "Multiply all numbers in my JSON request, by 2."

@@ -142,6 +189,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
       })
       bp.plugins:insert {
         name = PLUGIN_NAME,
+        id = "71083e79-4921-4f9f-97a4-ee7810b6cd8a",
         route = { id = without_response_instructions.id },
         config = {
           prompt = SYSTEM_PROMPT,
@@ -149,6 +197,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }

+      bp.plugins:insert {
+        name = "file-log",
+        route = { id = without_response_instructions.id },
+        config = {
+          path = FILE_LOG_PATH_STATS_ONLY,
+        },
+      }
+
       local bad_request = assert(bp.routes:insert {
         paths = { "/echo-bad-request" }
       })
@@ -216,6 +272,29 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         assert.same(EXPECTED_RESULT_FLAT, body_table.post_data.params)
       end)

+      it("logs statistics", function()
+        local r = client:get("/echo-flat", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = REQUEST_BODY,
+        })
+
+        local body = assert.res_status(200 , r)
+        local _, err = cjson.decode(body)
+
+        assert.is_nil(err)
+
+        local log_message = wait_for_json_log_entry(FILE_LOG_PATH_STATS_ONLY)
+        assert.same("127.0.0.1", log_message.client_ip)
+        assert.is_number(log_message.request.size)
+        assert.is_number(log_message.response.size)
+
+        -- test ai-proxy stats
+        assert.same(_EXPECTED_CHAT_STATS, log_message.ai)
+      end)
+
       it("bad request from LLM", function()
         local r = client:get("/echo-bad-request", {
           headers = {
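-- Not part of the patch: what the data:match("%b{}") step in the helper added above
-- does. Lua's %b{} pattern grabs the first balanced {...} group, which strips any
-- bytes surrounding the JSON object in the raw log line. Sample values are made up.
local line = 'prefix {"client_ip":"127.0.0.1","ai":{}} trailing'
assert(line:match("%b{}") == '{"client_ip":"127.0.0.1","ai":{}}')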
diff --git a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
index 2fdd5b11e71f..13e4b558a3ef 100644
--- a/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
+++ b/spec/03-plugins/40-ai-response-transformer/02-integration_spec.lua
@@ -1,9 +1,35 @@
 local helpers = require "spec.helpers"
 local cjson = require "cjson"
+local pl_file = require "pl.file"
+local pl_stringx = require "pl.stringx"

 local MOCK_PORT = helpers.get_available_port()
 local PLUGIN_NAME = "ai-response-transformer"

+local FILE_LOG_PATH_STATS_ONLY = os.tmpname()
+
+local function wait_for_json_log_entry(FILE_LOG_PATH)
+  local json
+
+  assert
+    .with_timeout(10)
+    .ignore_exceptions(true)
+    .eventually(function()
+      local data = assert(pl_file.read(FILE_LOG_PATH))
+
+      data = pl_stringx.strip(data)
+      assert(#data > 0, "log file is empty")
+
+      data = data:match("%b{}")
+      assert(data, "log file does not contain JSON")
+
+      json = cjson.decode(data)
+    end)
+    .has_no_error("log file contains a valid JSON entry")
+
+  return json
+end
+
 local OPENAI_INSTRUCTIONAL_RESPONSE = {
   route_type = "llm/v1/chat",
   model = {
@@ -23,6 +49,10 @@ local OPENAI_INSTRUCTIONAL_RESPONSE = {

 local OPENAI_FLAT_RESPONSE = {
   route_type = "llm/v1/chat",
+  logging = {
+    log_payloads = false,
+    log_statistics = true,
+  },
   model = {
     name = "gpt-4",
     provider = "openai",
@@ -141,6 +171,23 @@ local EXPECTED_RESULT = {
   },
 }

+local _EXPECTED_CHAT_STATS = {
+  ["ai-response-transformer"] = {
+    meta = {
+      plugin_id = 'da587462-a802-4c22-931a-e6a92c5866d1',
+      provider_name = 'openai',
+      request_model = 'gpt-4',
+      response_model = 'gpt-3.5-turbo-0613',
+    },
+    payload = {},
+    usage = {
+      completion_token = 12,
+      prompt_token = 25,
+      total_tokens = 37,
+    },
+  },
+}
+
 local SYSTEM_PROMPT = "You are a mathematician. " ..
   "Multiply all numbers in my JSON request, by 2. Return me this message: " ..
   "{\"status\": 400, \"headers: {\"content-type\": \"application/xml\"}, \"body\": \"OUTPUT\"} "
@@ -228,6 +275,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
       })
       bp.plugins:insert {
         name = PLUGIN_NAME,
+        id = "da587462-a802-4c22-931a-e6a92c5866d1",
         route = { id = without_response_instructions.id },
         config = {
           prompt = SYSTEM_PROMPT,
@@ -236,6 +284,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         },
       }

+      bp.plugins:insert {
+        name = "file-log",
+        route = { id = without_response_instructions.id },
+        config = {
+          path = FILE_LOG_PATH_STATS_ONLY,
+        },
+      }
+
       local bad_instructions = assert(bp.routes:insert {
         paths = { "/echo-bad-instructions" }
       })
@@ -345,6 +401,29 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         assert.same(EXPECTED_RESULT_FLAT, body_table)
       end)

+      it("logs statistics", function()
+        local r = client:get("/echo-flat", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = REQUEST_BODY,
+        })
+
+        local body = assert.res_status(200 , r)
+        local _, err = cjson.decode(body)
+
+        assert.is_nil(err)
+
+        local log_message = wait_for_json_log_entry(FILE_LOG_PATH_STATS_ONLY)
+        assert.same("127.0.0.1", log_message.client_ip)
+        assert.is_number(log_message.request.size)
+        assert.is_number(log_message.response.size)
+
+        -- test ai-proxy stats
+        assert.same(_EXPECTED_CHAT_STATS, log_message.ai)
+      end)
+
       it("fails properly when json instructions are bad", function()
         local r = client:get("/echo-bad-instructions", {
           headers = {