Skip to content

Commit

Permalink
fix(ai proxy): 3.7 regression fixes rollup (#12974)
Browse files Browse the repository at this point in the history
* fix(ai-proxy): remove unsupported top_k parameter from openai format(s)

* fix(ai-proxy): broken ollama-type streaming events

* fix(ai-proxy): streaming setting moved ai-proxy only

* fix(ai-proxy): anthropic token counts

* fix(ai-proxy): wrong analytics format; missing azure extra analytics

* fix(ai-proxy): correct plugin name in log serializer

* fix(ai-proxy): store body string in case of regression

* fix(ai-proxy): fix tests

(cherry picked from commit 8470056)
  • Loading branch information
tysoekong authored and locao committed May 6, 2024
1 parent a5530a7 commit 46c1dd1
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 47 deletions.
4 changes: 2 additions & 2 deletions kong/clustering/compat/checkers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ local compatible_checkers = {
if plugin.name == 'ai-proxy' then
local config = plugin.config
if config.model and config.model.options then
if config.model.options.response_streaming then
config.model.options.response_streaming = nil
if config.response_streaming then
config.response_streaming = nil
log_warn_message('configures ' .. plugin.name .. ' plugin with' ..
' response_streaming == nil, because it is not supported' ..
' in this release',
Expand Down
7 changes: 3 additions & 4 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ local function handle_stream_event(event_t, model_info, route_type)
and event_data.usage then
return nil, nil, {
prompt_tokens = nil,
completion_tokens = event_data.meta.usage
and event_data.meta.usage.output_tokens
completion_tokens = event_data.usage.output_tokens
or nil,
stop_reason = event_data.delta
and event_data.delta.stop_reason
Expand Down Expand Up @@ -343,7 +342,7 @@ function _M.from_format(response_string, model_info, route_type)
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transform, response_string, model_info, route_type)
local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
model_info.provider,
Expand All @@ -352,7 +351,7 @@ function _M.from_format(response_string, model_info, route_type)
)
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
Expand Down
8 changes: 5 additions & 3 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ function _M.pre_request(conf)

-- for azure provider, all of these must/will be set by now
if conf.logging and conf.logging.log_statistics then
kong.log.set_serialize_value("ai.meta.azure_instance_id", conf.model.options.azure_instance)
kong.log.set_serialize_value("ai.meta.azure_deployment_id", conf.model.options.azure_deployment_id)
kong.log.set_serialize_value("ai.meta.azure_api_version", conf.model.options.azure_api_version)
kong.ctx.plugin.ai_extra_meta = {
["azure_instance_id"] = conf.model.options.azure_instance,
["azure_deployment_id"] = conf.model.options.azure_deployment_id,
["azure_api_version"] = conf.model.options.azure_api_version,
}
end

return true
Expand Down
2 changes: 2 additions & 0 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ local transformers_to = {
["llm/v1/chat"] = function(request_table, model_info, route_type)
request_table.model = request_table.model or model_info.name
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model_info, route_type)
request_table.model = model_info.name
request_table.stream = request_table.stream or false -- explicitly set this
request_table.top_k = nil -- explicitly remove unsupported default

return request_table, "application/json", nil
end,
Expand Down
58 changes: 41 additions & 17 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ local log_entry_keys = {
TOKENS_CONTAINER = "usage",
META_CONTAINER = "meta",
PAYLOAD_CONTAINER = "payload",
REQUEST_BODY = "ai.payload.request",

-- payload keys
REQUEST_BODY = "request",
RESPONSE_BODY = "response",

-- meta keys
Expand Down Expand Up @@ -271,20 +271,30 @@ function _M.to_ollama(request_table, model)
end

function _M.from_ollama(response_string, model_info, route_type)
local output, _, analytics

local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode ollama response"
end
local output, err, _, analytics

if route_type == "stream/llm/v1/chat" then
local response_table, err = cjson.decode(response_string.data)
if err then
return nil, "failed to decode ollama response"
end

output, _, analytics = handle_stream_event(response_table, model_info, route_type)

elseif route_type == "stream/llm/v1/completions" then
local response_table, err = cjson.decode(response_string.data)
if err then
return nil, "failed to decode ollama response"
end

output, _, analytics = handle_stream_event(response_table, model_info, route_type)

else
local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode ollama response"
end

-- there is no direct field indicating STOP reason, so calculate it manually
local stop_length = (model_info.options and model_info.options.max_tokens) or -1
local stop_reason = "stop"
Expand Down Expand Up @@ -412,14 +422,14 @@ function _M.pre_request(conf, request_table)
request_table[auth_param_name] = auth_param_value
end

if conf.logging and conf.logging.log_statistics then
kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
end

-- if enabled AND request type is compatible, capture the input for analytics
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
local plugin_name = conf.__key__:match('plugins:(.-):')
if not plugin_name or plugin_name == "" then
return nil, "no plugin name is being passed by the plugin"
end

kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body())
end

-- log tokens prompt for reports and billing
Expand Down Expand Up @@ -475,7 +485,6 @@ function _M.post_request(conf, response_object)
if not request_analytics_plugin then
request_analytics_plugin = {
[log_entry_keys.META_CONTAINER] = {},
[log_entry_keys.PAYLOAD_CONTAINER] = {},
[log_entry_keys.TOKENS_CONTAINER] = {
[log_entry_keys.PROMPT_TOKEN] = 0,
[log_entry_keys.COMPLETION_TOKEN] = 0,
Expand All @@ -485,11 +494,18 @@ function _M.post_request(conf, response_object)
end

-- Set the model, response, and provider names in the current try context
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = kong.ctx.plugin.llm_model_requested or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id

-- set extra per-provider meta
if kong.ctx.plugin.ai_extra_meta and type(kong.ctx.plugin.ai_extra_meta) == "table" then
for k, v in pairs(kong.ctx.plugin.ai_extra_meta) do
request_analytics_plugin[log_entry_keys.META_CONTAINER][k] = v
end
end

-- Capture openai-format usage stats from the transformed response body
if response_object.usage then
if response_object.usage.prompt_tokens then
Expand All @@ -505,16 +521,24 @@ function _M.post_request(conf, response_object)

-- Log response body if logging payloads is enabled
if conf.logging and conf.logging.log_payloads then
request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string
kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.RESPONSE_BODY), body_string)
end

-- Update context with changed values
request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER] = {
[log_entry_keys.RESPONSE_BODY] = body_string,
}
request_analytics[plugin_name] = request_analytics_plugin
kong.ctx.shared.analytics = request_analytics

if conf.logging and conf.logging.log_statistics then
-- Log analytics data
kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin)
kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.TOKENS_CONTAINER),
request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER])

-- Log meta
kong.log.set_serialize_value(fmt("ai.%s.%s", plugin_name, log_entry_keys.META_CONTAINER),
request_analytics_plugin[log_entry_keys.META_CONTAINER])
end

-- log tokens response for reports and billing
Expand Down
6 changes: 0 additions & 6 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ local model_options_schema = {
type = "record",
required = false,
fields = {
{ response_streaming = {
type = "string",
description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
required = false,
default = "allow",
one_of = { "allow", "deny", "always" } }},
{ max_tokens = {
type = "integer",
description = "Defines the max_tokens, if using chat or completion models.",
Expand Down
23 changes: 16 additions & 7 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ local function handle_streaming_frame(conf)
if not event_t then
event_t, err = cjson.decode(formatted)
end

if not err then
if not token_t then
token_t = get_token_text(event_t)
Expand All @@ -166,17 +166,24 @@ local function handle_streaming_frame(conf)
(kong.ctx.plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4)
end
end

elseif metadata then
kong.ctx.plugin.ai_stream_completion_tokens = metadata.completion_tokens or kong.ctx.plugin.ai_stream_completion_tokens
kong.ctx.plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or kong.ctx.plugin.ai_stream_prompt_tokens
end
end

framebuffer:put("data: ")
framebuffer:put(formatted or "")
framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "")
end

if conf.logging and conf.logging.log_statistics and metadata then
kong.ctx.plugin.ai_stream_completion_tokens =
(kong.ctx.plugin.ai_stream_completion_tokens or 0) +
(metadata.completion_tokens or 0)
or kong.ctx.plugin.ai_stream_completion_tokens
kong.ctx.plugin.ai_stream_prompt_tokens =
(kong.ctx.plugin.ai_stream_prompt_tokens or 0) +
(metadata.prompt_tokens or 0)
or kong.ctx.plugin.ai_stream_prompt_tokens
end
end
end

Expand Down Expand Up @@ -427,10 +434,12 @@ function _M:access(conf)
-- check if the user has asked for a stream, and/or if
-- we are forcing all requests to be of streaming type
if request_table and request_table.stream or
(conf_m.model.options and conf_m.model.options.response_streaming) == "always" then
(conf_m.response_streaming and conf_m.response_streaming == "always") then
request_table.stream = true

-- this condition will only check if user has tried
-- to activate streaming mode within their request
if conf_m.model.options and conf_m.model.options.response_streaming == "deny" then
if conf_m.response_streaming and conf_m.response_streaming == "deny" then
return bad_request("response streaming is not enabled for this LLM")
end

Expand Down
20 changes: 19 additions & 1 deletion kong/plugins/ai-proxy/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@

local typedefs = require("kong.db.schema.typedefs")
local llm = require("kong.llm")
local deep_copy = require("kong.tools.utils").deep_copy

local this_schema = deep_copy(llm.config_schema)

local ai_proxy_only_config = {
{
response_streaming = {
type = "string",
description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.",
required = false,
default = "allow",
one_of = { "allow", "deny", "always" }},
},
}

for i, v in pairs(ai_proxy_only_config) do
this_schema.fields[#this_schema.fields+1] = v
end

return {
name = "ai-proxy",
fields = {
{ protocols = typedefs.protocols_http },
{ consumer = typedefs.no_consumer },
{ service = typedefs.no_service },
{ config = llm.config_schema },
{ config = this_schema },
},
}
4 changes: 2 additions & 2 deletions spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
name = "ai-proxy",
enabled = true,
config = {
response_streaming = "allow", -- becomes nil
route_type = "preserve", -- becomes 'llm/v1/chat'
auth = {
header_name = "header",
Expand All @@ -561,7 +562,6 @@ describe("CP/DP config compat transformations #" .. strategy, function()
options = {
max_tokens = 512,
temperature = 0.5,
response_streaming = "allow", -- becomes nil
upstream_path = "/anywhere", -- becomes nil
},
},
Expand All @@ -570,7 +570,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- ]]

local expected_ai_proxy_prior_37 = utils.cycle_aware_deep_copy(ai_proxy)
expected_ai_proxy_prior_37.config.model.options.response_streaming = nil
expected_ai_proxy_prior_37.config.response_streaming = nil
expected_ai_proxy_prior_37.config.model.options.upstream_path = nil
expected_ai_proxy_prior_37.config.route_type = "llm/v1/chat"

Expand Down
5 changes: 2 additions & 3 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-3.5-turbo',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down Expand Up @@ -782,8 +781,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
assert.is_number(log_message.response.size)

-- test request bodies
assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true)
assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)
assert.matches('"content": "What is 1 + 1?"', log_message.ai['ai-proxy'].payload.request, nil, true)
assert.matches('"role": "user"', log_message.ai['ai-proxy'].payload.request, nil, true)

-- test response bodies
assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai["ai-proxy"].payload.response, nil, true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-4',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ local _EXPECTED_CHAT_STATS = {
request_model = 'gpt-4',
response_model = 'gpt-3.5-turbo-0613',
},
payload = {},
usage = {
completion_token = 12,
prompt_token = 25,
Expand Down

0 comments on commit 46c1dd1

Please sign in to comment.