refactor(ai-proxy) clean up code and fix gzip response #13155

Merged: 5 commits, merged Aug 7, 2024
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-ai-gzip-content.yml
@@ -0,0 +1,4 @@
message: |
**AI-Proxy**: Fixed an issue where the response was gzipped even when the client does not accept gzip encoding.
type: bugfix
scope: Plugin
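The hunks shown in this PR do not include the gzip handling itself, but the changelog entry above describes the intended behavior: the proxy should not return a gzip-compressed body to a client that did not advertise gzip support. Below is a minimal, hypothetical Lua sketch of that check; the helper name and the use of `kong.tools.gzip` are assumptions for illustration, not the PR's actual implementation.

```lua
-- Hypothetical sketch only; not the code changed by this PR.
-- Assumes kong.tools.gzip exposes inflate_gzip(body) -> inflated, err.
local gzip = require("kong.tools.gzip")

-- Decide whether an upstream body must be inflated before it is sent to the client.
local function adjust_response_encoding(accept_encoding, upstream_encoding, body)
  local client_accepts_gzip = accept_encoding ~= nil
    and accept_encoding:find("gzip", 1, true) ~= nil

  if upstream_encoding == "gzip" and not client_accepts_gzip then
    -- the upstream compressed the response, but the client never asked for gzip:
    -- inflate the body and tell the caller to drop the Content-Encoding header
    local inflated, err = gzip.inflate_gzip(body)
    if not inflated then
      return nil, "failed to inflate upstream body: " .. (err or "unknown error")
    end
    return inflated, nil, true
  end

  -- pass the body through unchanged; no header adjustment needed
  return body, nil, false
end

return adjust_response_encoding
```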
2 changes: 2 additions & 0 deletions kong-3.8.0-0.rockspec
@@ -614,6 +614,8 @@ build = {
["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua",
["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua",
["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",
["kong.llm.state"] = "kong/llm/state.lua",

["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua",
["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua",

4 changes: 2 additions & 2 deletions kong/llm/drivers/anthropic.lua
@@ -71,7 +71,7 @@ local function to_claude_prompt(req)
return kong_messages_to_claude_prompt(req.messages)

end

return nil, "request is missing .prompt and .messages commands"
end

@@ -328,7 +328,7 @@ function _M.from_format(response_string, model_info, route_type)
if not transform then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
2 changes: 1 addition & 1 deletion kong/llm/drivers/azure.lua
@@ -139,7 +139,7 @@ function _M.configure_request(conf)
-- technically min supported version
query_table["api-version"] = kong.request.get_query_arg("api-version")
or (conf.model.options and conf.model.options.azure_api_version)

if auth_param_name and auth_param_value and auth_param_location == "query" then
query_table[auth_param_name] = auth_param_value
end
2 changes: 1 addition & 1 deletion kong/llm/drivers/bedrock.lua
@@ -266,7 +266,7 @@ function _M.from_format(response_string, model_info, route_type)
if not transformers_from[route_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
52 changes: 26 additions & 26 deletions kong/llm/drivers/cohere.lua
@@ -21,7 +21,7 @@ local _CHAT_ROLES = {

local function handle_stream_event(event_t, model_info, route_type)
local metadata

-- discard empty frames, it should either be a random new line, or comment
if (not event_t.data) or (#event_t.data < 1) then
return
@@ -31,12 +31,12 @@ local function handle_stream_event(event_t, model_info, route_type)
if err then
return nil, "failed to decode event frame from cohere: " .. err, nil
end

local new_event

if event.event_type == "stream-start" then
kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id

-- ignore the rest of this one
new_event = {
choices = {
@@ -52,7 +52,7 @@ local function handle_stream_event(event_t, model_info, route_type)
model = model_info.name,
object = "chat.completion.chunk",
}

elseif event.event_type == "text-generation" then
-- this is a token
if route_type == "stream/llm/v1/chat" then
@@ -137,19 +137,19 @@ end
local function handle_json_inference_event(request_table, model)
request_table.temperature = request_table.temperature
request_table.max_tokens = request_table.max_tokens

request_table.p = request_table.top_p
request_table.k = request_table.top_k

request_table.top_p = nil
request_table.top_k = nil

request_table.model = model.name or request_table.model
request_table.stream = request_table.stream or false -- explicitly set this

if request_table.prompt and request_table.messages then
return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema")

elseif request_table.messages then
-- we have to move all BUT THE LAST message into "chat_history" array
-- and move the LAST message (from 'user') into "message" string
@@ -164,26 +164,26 @@ local function handle_json_inference_event(request_table, model)
else
role = _CHAT_ROLES.user
end

chat_history[i] = {
role = role,
message = v.content,
}
end
end

request_table.chat_history = chat_history
end

request_table.message = request_table.messages[#request_table.messages].content
request_table.messages = nil

elseif request_table.prompt then
request_table.prompt = request_table.prompt
request_table.messages = nil
request_table.message = nil
end

return request_table, "application/json", nil
end

@@ -202,7 +202,7 @@ local transformers_from = {
-- messages/choices table is only 1 size, so don't need to static allocate
local messages = {}
messages.choices = {}

if response_table.prompt and response_table.generations then
-- this is a "co.generate"
for i, v in ipairs(response_table.generations) do
@@ -215,7 +215,7 @@
messages.object = "text_completion"
messages.model = model_info.name
messages.id = response_table.id

local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
@@ -230,10 +230,10 @@
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

elseif response_table.text then
-- this is a "co.chat"

messages.choices[1] = {
index = 0,
message = {
@@ -245,7 +245,7 @@
messages.object = "chat.completion"
messages.model = model_info.name
messages.id = response_table.generation_id

local stats = {
completion_tokens = response_table.meta
and response_table.meta.billed_units
@@ -260,10 +260,10 @@
and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens),
}
messages.usage = stats

else -- probably a fault
return nil, "'text' or 'generations' missing from cohere response body"

end

return cjson.encode(messages)
@@ -314,17 +314,17 @@ local transformers_from = {
prompt.object = "chat.completion"
prompt.model = model_info.name
prompt.id = response_table.generation_id

local stats = {
completion_tokens = response_table.token_count and response_table.token_count.response_tokens,
prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens,
total_tokens = response_table.token_count and response_table.token_count.total_tokens,
}
prompt.usage = stats

else -- probably a fault
return nil, "'text' or 'generations' missing from cohere response body"

end

return cjson.encode(prompt)
@@ -465,7 +465,7 @@ function _M.configure_request(conf)
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

4 changes: 2 additions & 2 deletions kong/llm/drivers/gemini.lua
@@ -123,7 +123,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
}
end
end

new_r.generationConfig = to_gemini_generation_config(request_table)

return new_r, "application/json", nil
@@ -222,7 +222,7 @@ function _M.from_format(response_string, model_info, route_type)
if not transformers_from[route_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
2 changes: 1 addition & 1 deletion kong/llm/drivers/llama2.lua
@@ -113,7 +113,7 @@ local function to_raw(request_table, model)
messages.parameters.top_k = request_table.top_k
messages.parameters.temperature = request_table.temperature
messages.parameters.stream = request_table.stream or false -- explicitly set this

if request_table.prompt and request_table.messages then
return kong.response.exit(400, "cannot run raw 'prompt' and chat history 'messages' requests at the same time - refer to schema")

14 changes: 7 additions & 7 deletions kong/llm/drivers/openai.lua
@@ -35,23 +35,23 @@ local transformers_to = {
}

local transformers_from = {
["llm/v1/chat"] = function(response_string, model_info)
["llm/v1/chat"] = function(response_string, _)
local response_object, err = cjson.decode(response_string)
if err then
return nil, "'choices' not in llm/v1/chat response"
return nil, "failed to decode llm/v1/chat response"
end

if response_object.choices then
return response_string, nil
else
return nil, "'choices' not in llm/v1/chat response"
end
end,

["llm/v1/completions"] = function(response_string, model_info)
["llm/v1/completions"] = function(response_string, _)
local response_object, err = cjson.decode(response_string)
if err then
return nil, "'choices' not in llm/v1/completions response"
return nil, "failed to decode llm/v1/completions response"
end

if response_object.choices then
@@ -72,7 +72,7 @@ function _M.from_format(response_string, model_info, route_type)
if not transformers_from[route_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
@@ -203,7 +203,7 @@ function _M.configure_request(conf)
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

29 changes: 15 additions & 14 deletions kong/llm/drivers/shared.lua
@@ -1,11 +1,12 @@
local _M = {}

-- imports
local cjson = require("cjson.safe")
local http = require("resty.http")
local fmt = string.format
local os = os
local parse_url = require("socket.url").parse
local cjson = require("cjson.safe")
local http = require("resty.http")
local fmt = string.format
local os = os
local parse_url = require("socket.url").parse
local llm_state = require("kong.llm.state")
local aws_stream = require("kong.tools.aws_stream")
--

@@ -258,7 +259,7 @@ function _M.frame_to_events(frame, provider)

-- it may start with ',' which is the start of the new frame
frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame

-- it may end with the array terminator ']' indicating the finished stream
if string.sub(str_rtrim(frame), -1) == "]" then
frame = string.sub(str_rtrim(frame), 1, -2)
@@ -341,7 +342,7 @@ function _M.frame_to_events(frame, provider)
end -- if
end
end

return events
end

@@ -445,7 +446,7 @@ function _M.from_ollama(response_string, model_info, route_type)

end
end

if output and output ~= _M._CONST.SSE_TERMINATOR then
output, err = cjson.encode(output)
end
@@ -500,7 +501,7 @@ function _M.resolve_plugin_conf(kong_request, conf)
if #splitted ~= 2 then
return nil, "cannot parse expression for field '" .. v .. "'"
end

-- find the request parameter, with the configured name
prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2])
if err then
@@ -524,7 +525,7 @@ function _M.pre_request(conf, request_table)
local auth_param_name = conf.auth and conf.auth.param_name
local auth_param_value = conf.auth and conf.auth.param_value
local auth_param_location = conf.auth and conf.auth.param_location

if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then
request_table[auth_param_name] = auth_param_value
end
@@ -547,7 +548,7 @@ function _M.pre_request(conf, request_table)
kong.log.warn("failed calculating cost for prompt tokens: ", err)
prompt_tokens = 0
end
kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens
llm_state.increase_prompt_tokens_count(prompt_tokens)
end

local start_time_key = "ai_request_start_time_" .. plugin_name
@@ -586,7 +587,7 @@ function _M.post_request(conf, response_object)
end

-- check if we already have analytics in this context
local request_analytics = kong.ctx.shared.analytics
local request_analytics = llm_state.get_request_analytics()

-- create a new structure if not
if not request_analytics then
@@ -657,7 +658,7 @@ function _M.post_request(conf, response_object)
[log_entry_keys.RESPONSE_BODY] = body_string,
}
request_analytics[plugin_name] = request_analytics_plugin
kong.ctx.shared.analytics = request_analytics
llm_state.set_request_analytics(request_analytics)

if conf.logging and conf.logging.log_statistics then
-- Log meta data
@@ -679,7 +680,7 @@
kong.log.warn("failed calculating cost for response tokens: ", err)
response_tokens = 0
end
kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens
llm_state.increase_response_tokens_count(response_tokens)

return nil
end
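
The hunks above swap direct `kong.ctx.shared.*` reads and writes for accessors on the new `kong.llm.state` module registered in the rockspec, but the module itself is not part of the shown diff. A minimal sketch of what those accessors could look like follows, assuming they remain thin wrappers around the same `kong.ctx.shared` fields that the removed lines used.

```lua
-- Hypothetical sketch of kong/llm/state.lua; the real module is not shown in this diff.
-- Assumes the Kong PDK global `kong` and that state still lives in kong.ctx.shared,
-- mirroring the fields the old code touched directly.
local _M = {}

function _M.increase_prompt_tokens_count(count)
  kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + count
  return kong.ctx.shared.ai_prompt_tokens
end

function _M.increase_response_tokens_count(count)
  kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + count
  return kong.ctx.shared.ai_response_tokens
end

function _M.get_request_analytics()
  return kong.ctx.shared.analytics
end

function _M.set_request_analytics(analytics)
  kong.ctx.shared.analytics = analytics
end

return _M
```

Centralizing these accessors means a later refactor can move the storage out of `kong.ctx.shared` without touching every driver.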