Commit

finished gemini (text-only) support

tysoekong committed Jul 4, 2024
1 parent 9d0756d commit e10ee63
Showing 3 changed files with 126 additions and 38 deletions.
75 changes: 68 additions & 7 deletions kong/llm/drivers/gemini.lua
@@ -32,6 +32,69 @@ local function to_gemini_generation_config(request_table)
}
end

local function handle_stream_event(event_t, model_info, route_type)
  -- discard empty frames; these are usually a stray newline or an SSE comment
  if (not event_t.data) or (#event_t.data < 1) then
    return
  end

  local event, err = cjson.decode(event_t.data)
  if err then
    ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err)
    return nil, "failed to decode stream event frame from gemini", nil
  end

  local new_event
  local metadata

  if event.candidates and #event.candidates > 0 then
    if event.candidates[1].content and
       event.candidates[1].content.parts and
       #event.candidates[1].content.parts > 0 and
       event.candidates[1].content.parts[1].text then

      new_event = {
        choices = {
          [1] = {
            delta = {
              content = event.candidates[1].content.parts[1].text or "",
              role = "assistant",
            },
            index = 0,
          },
        },
      }
    end

    if event.candidates[1].finishReason then
      metadata = metadata or {}
      metadata.finished_reason = event.candidates[1].finishReason
      new_event = "[DONE]"
    end
  end

  if event.usageMetadata then
    metadata = metadata or {}
    metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0
    metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0
  end

  if new_event then
    if new_event ~= "[DONE]" then
      new_event = cjson.encode(new_event)
    end

    return new_event, nil, metadata
  else
    return nil, nil, metadata -- caller code will handle "unrecognised" event types
  end
end
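For orientation (editor's illustration, not part of the commit): a minimal standalone sketch of the mapping handle_stream_event performs. The payload shape follows the public Gemini REST API, but the exact frames a live stream emits may vary; the snippet assumes only lua-cjson.

-- Hypothetical demo payload; runnable with plain Lua + lua-cjson.
local cjson = require("cjson.safe")

local sample_frame = [[{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":1}}]]

local event = cjson.decode(sample_frame)

-- the handler maps this to an OpenAI-style streaming delta...
local openai_chunk = {
  choices = {
    [1] = {
      delta = {
        content = event.candidates[1].content.parts[1].text, -- "Hello"
        role = "assistant",
      },
      index = 0,
    },
  },
}
print(cjson.encode(openai_chunk))

-- ...plus metadata { prompt_tokens = 5, completion_tokens = 1 } taken
-- from usageMetadata, and "[DONE]" once finishReason appears.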

local function to_gemini_chat_openai(request_table, model_info, route_type)
if request_table then -- try-catch type mechanism
local new_r = {}
@@ -180,12 +243,11 @@ end

local transformers_to = {
  ["llm/v1/chat"] = to_gemini_chat_openai,
-  ["gemini/v1/chat"] = to_gemini_chat_gemini,
}

local transformers_from = {
  ["llm/v1/chat"] = from_gemini_chat_openai,
-  ["gemini/v1/chat"] = from_gemini_chat_gemini,
+  ["stream/llm/v1/chat"] = handle_stream_event,
}

function _M.from_format(response_string, model_info, route_type)
@@ -196,7 +258,7 @@ function _M.from_format(response_string, model_info, route_type)
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

-  local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info, route_type)
+  local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
  if not ok or err then
    return nil, fmt("transformation failed from type %s://%s: %s",
      model_info.provider,
@@ -205,7 +267,7 @@
      )
  end

-  return response_string, nil
+  return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
@@ -302,7 +364,8 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli
-- disable gzip for gemini because it breaks streaming
kong.service.request.set_header("Accept-Encoding", "identity")

return true, nil
end
@@ -341,8 +404,6 @@ function _M.configure_request(conf, identity_interface)

parsed_url = socket_url.parse(f_url)

-  kong.log.inspect(parsed_url)

  if conf.model.options and conf.model.options.upstream_path then
    -- upstream path override is set (or templated from request params)
    parsed_url.path = conf.model.options.upstream_path
63 changes: 40 additions & 23 deletions kong/llm/drivers/shared.lua
@@ -55,6 +55,7 @@ _M.streaming_has_token_counts = {
  ["cohere"] = true,
  ["llama2"] = true,
  ["anthropic"] = true,
+  ["gemini"] = true,
}

_M.upstream_url_format = {
@@ -216,12 +217,12 @@ end
-- as if it were an SSE message.
--
-- @param {string} frame input string to format into SSE events
- -- @param {string} delimiter delimeter (can be complex string) to split by
+ -- @param {boolean} raw_json sets application/json byte-parser mode
-- @return {table} n number of split SSE messages, or empty table
- function _M.frame_to_events(frame)
+ function _M.frame_to_events(frame, raw_json_mode)
  local events = {}

-  if (not frame) or #frame < 1 then
+  if (not frame) or (#frame < 1) or (type(frame) ~= "string") then
    return
  end

@@ -234,36 +235,52 @@ function _M.frame_to_events(frame)
        data = event,
      }
    end

  -- some new LLMs return the JSON object-by-object,
  -- because that totally makes sense to parse?!
  elseif raw_json_mode then
    -- if this is the first frame, it will begin with array opener '['
    frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame

    -- it may start with ',' which is the start of the new frame
    frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame

    -- finally, it may end with the array terminator ']' indicating the finished stream
    frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame

    -- for multiple events that arrive in the same frame, split by top-level comma
    for _, v in ipairs(split(frame, "\n,")) do
      events[#events + 1] = { data = v }
    end

  else
    -- standard SSE parser
    local event_lines = split(frame, "\n")
    local struct = { event = nil, id = nil, data = nil }

    for i, dat in ipairs(event_lines) do
      if #dat < 1 then
        events[#events + 1] = struct
        struct = { event = nil, id = nil, data = nil }
      end

      -- test for truncated chunk on the last line (no trailing \r\n\r\n)
      if #dat > 0 and #event_lines == i then
        ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head")
        kong.ctx.plugin.truncated_frame = dat
        break -- stop parsing immediately, server has done something wrong
      end

      -- test for abnormal start-of-frame (truncation tail)
      if kong and kong.ctx.plugin.truncated_frame then
        -- this is the tail of a previous incomplete chunk
        ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail")
        dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat)
        kong.ctx.plugin.truncated_frame = nil
      end

      local s1, _ = str_find(dat, ":") -- find where the cut point is

      if s1 and s1 ~= 1 then
        local field = str_sub(dat, 1, s1 - 1) -- returns "data" from "data: hello world"
        local value = str_ltrim(str_sub(dat, s1 + 1)) -- returns "hello world" from "data: hello world"

        -- for now, not checking if the value has already been set
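Editor's note on the raw-JSON branch above: Gemini streams one JSON array delivered piecemeal, so a frame may carry the array opener '[', a leading ',', or the closing ']'. Below is a self-contained sketch of that trim-and-split logic; the helper names are illustrative stand-ins for this file's str_ltrim and split, not the real kong/llm helpers.

-- Plain-Lua demo of the framing logic; no Kong required.
local function ltrim(s) return (s:gsub("^%s+", "")) end

local function strip_array_framing(frame)
  frame = ltrim(frame)
  if frame:sub(1, 1) == "[" then frame = frame:sub(2) end    -- first frame opens the array
  frame = ltrim(frame)
  if frame:sub(1, 1) == "," then frame = frame:sub(2) end    -- continuation frames lead with ','
  if frame:sub(-1) == "]" then frame = frame:sub(1, -2) end  -- final frame closes the array
  return frame
end

-- split on the literal "\n," delimiter, like split(frame, "\n,") above
local function split_events(s)
  local out = {}
  for piece in (s .. "\n,"):gmatch("(.-)\n,") do
    out[#out + 1] = piece
  end
  return out
end

local frame = '[{"text":"Hel"}\n,{"text":"lo"}]'
for _, v in ipairs(split_events(strip_array_framing(frame))) do
  print(v)  -- {"text":"Hel"}  then  {"text":"lo"}
end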
26 changes: 18 additions & 8 deletions kong/plugins/ai-proxy/handler.lua
Expand Up @@ -24,6 +24,11 @@ local _M = {
}


-- static messages
local ERROR_MSG = { error = { message = "" } }
local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}'


local _KEYBASTION = setmetatable({}, {
__mode = "k",
__index = function(this_cache, plugin_config)
@@ -92,14 +97,17 @@ local function handle_streaming_frame(conf)
-- because we have already 200 OK'd the client by now

  if (not finished) and (is_gzip) then
-    chunk = kong_utils.inflate_gzip(chunk)
+    chunk = kong_utils.inflate_gzip(ngx.arg[1])
  end

-  local events = ai_shared.frame_to_events(chunk)
+  local events = ai_shared.frame_to_events(chunk, conf.model.provider == "gemini")

  if not events then
-    local response = 'data: {"error": true, "message": "empty transformer response"}'
+    -- usually a not-supported-transformer or empty frames.
+    -- header_filter has already run, so all we can do is log it,
+    -- and then send the client a readable error in a single chunk
+    local response = ERROR__NOT_SET

    if is_gzip then
      response = kong_utils.deflate_gzip(response)
    end
@@ -419,8 +427,9 @@ function _M:access(conf)
return bad_request("response streaming is not enabled for this LLM")
end

-  -- store token cost estimate, on first pass
-  if not kong_ctx_plugin.ai_stream_prompt_tokens then
+  -- store token cost estimate, on first pass, if the
+  -- provider doesn't reply with a prompt token count
+  if (not kong.ctx.plugin.ai_stream_prompt_tokens) and (not ai_shared.streaming_has_token_counts[conf_m.model.provider]) then
    local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8)
    if err then
      kong.log.err("unable to estimate request token cost: ", err)
@@ -468,15 +477,16 @@

  -- get the provider's cached identity interface - nil may come back, which is fine
  local identity_interface = _KEYBASTION[conf]
-  if identity_interface.error then
+  if identity_interface and identity_interface.error then
    kong.ctx.shared.skip_response_transformer = true
    kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)

    return internal_server_error("LLM request failed before proxying")
  end

  -- now re-configure the request for this operation type
-  local ok, err = ai_driver.configure_request(conf_m, identity_interface.interface)
+  local ok, err = ai_driver.configure_request(conf_m,
+      identity_interface and identity_interface.interface)
  if not ok then
    kong_ctx_shared.skip_response_transformer = true
    kong.log.err("failed to configure request for AI service: ", err)
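Editor's sketch of the token-estimate gating introduced above: estimation now runs only when the provider is not listed in ai_shared.streaming_has_token_counts. The estimator below is an assumption for illustration only; the real arithmetic lives in ai_shared.calculate_cost, and just the 1.8 ratio is taken from the call above.

-- Simplified restatement of the gating logic. ASSUMPTION: calculate_cost is
-- approximated here as word-count * ratio; the real helper lives in
-- kong/llm/drivers/shared.lua and may count differently.
local streaming_has_token_counts = {
  cohere = true, llama2 = true, anthropic = true, gemini = true,
}

local function estimate_tokens(text, ratio)
  local words = 0
  for _ in text:gmatch("%S+") do words = words + 1 end
  return math.floor(words * ratio)
end

local function maybe_estimate(provider, prompt, cached)
  -- skip when a count is already stored, or the provider reports usage in-stream
  if cached or streaming_has_token_counts[provider] then
    return cached
  end
  return estimate_tokens(prompt, 1.8)
end

print(maybe_estimate("gemini", "hello there world", nil))  -- nil (gemini streams usageMetadata)
print(maybe_estimate("openai", "hello there world", nil))  -- 5 (3 words * 1.8, floored)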
