feat(ai-proxy): add streaming support and transformers (#12792)
* feat(ai-proxy): add streaming support and transformers

* feat(ai-proxy): streaming unit tests; hop-by-hop headers

* fix cohere empty comments

* fix(syntax): shared text extractor for ai token

* fix(ai-proxy): integration tests for streaming

* fix(ai-proxy): integration tests for streaming

* Update 09-streaming_integration_spec.lua

* Update kong/llm/init.lua

Co-authored-by: Michael Martin <[email protected]>

* discussion_r1560031734

* discussion_r1560047662

* discussion_r1560109626

* discussion_r1560117584

* discussion_r1560120287

* discussion_r1560121506

* discussion_r1560123437

* discussion_r1561272376

---------

Co-authored-by: Michael Martin <[email protected]>
tysoekong and flrgh authored Apr 12, 2024
1 parent 4bcc53c commit cb1b163
Showing 26 changed files with 1,292 additions and 120 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml
@@ -0,0 +1,4 @@
message: |
  **AI-Proxy**: add support for streaming event-by-event responses back to the client on supported providers
scope: Plugin
type: feature
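
For context on what this enables: the OpenAI-format transformers later in this commit pass a `stream` flag through from the client request, so streaming is opt-in per call. A minimal sketch of a request body that would trigger it (field values are illustrative, not taken from this commit):

-- hypothetical llm/v1/chat request body; "stream" is the new opt-in flag
local body = {
  messages = {
    { role = "user", content = "Say hello" },
  },
  stream = true, -- ask for event-by-event chunks instead of one response
}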
4 changes: 2 additions & 2 deletions kong/llm/drivers/anthropic.lua
@@ -272,13 +272,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
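A note on the pattern repeated across all five drivers in this commit: `ai_shared.http_request` now receives the `return_res_table` flag and returns the underlying `resty.http` client, which `subrequest` passes back as a fourth value so callers can read the body incrementally. A hedged sketch of how a caller might consume it, assuming lua-resty-http's standard `body_reader` interface (`driver` stands for any of these modules; this is not code from the commit):

-- sketch: stream the upstream body via the returned client handle
local res, status, err, httpc = driver.subrequest(body, conf, http_opts, true)
if err then
  return nil, "subrequest failed: " .. err
end

local reader = res.body_reader -- chunked-body iterator from lua-resty-http
repeat
  local chunk, read_err = reader()
  if chunk then
    ngx.print(chunk) -- forward each frame downstream as it arrives
    ngx.flush(true)
  end
until not chunk or read_err

httpc:set_keepalive() -- hand the connection back to the pool when done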
5 changes: 2 additions & 3 deletions kong/llm/drivers/azure.lua
@@ -59,13 +59,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
@@ -82,7 +82,6 @@ end

-- returns err or nil
function _M.configure_request(conf)

local parsed_url

if conf.model.options.upstream_url then
126 changes: 121 additions & 5 deletions kong/llm/drivers/cohere.lua
@@ -12,6 +12,119 @@ local table_new = require("table.new")
local DRIVER_NAME = "cohere"
--

local function handle_stream_event(event_string, model_info, route_type)
local metadata

-- discard empty frames, it should either be a random new line, or comment
if #event_string < 1 then
return
end

local event, err = cjson.decode(event_string)
if err then
return nil, "failed to decode event frame from cohere: " .. err, nil
end

local new_event

if event.event_type == "stream-start" then
kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id

-- ignore the rest of this one
new_event = {
choices = {
[1] = {
delta = {
content = "",
role = "assistant",
},
index = 0,
},
},
id = event.generation_id,
model = model_info.name,
object = "chat.completion.chunk",
}

elseif event.event_type == "text-generation" then
-- this is a token
if route_type == "stream/llm/v1/chat" then
new_event = {
choices = {
[1] = {
delta = {
content = event.text or "",
},
index = 0,
finish_reason = cjson.null,
logprobs = cjson.null,
},
},
id = kong
and kong.ctx
and kong.ctx.plugin
and kong.ctx.plugin.ai_proxy_cohere_stream_id,
model = model_info.name,
object = "chat.completion.chunk",
}

elseif route_type == "stream/llm/v1/completions" then
new_event = {
choices = {
[1] = {
text = event.text or "",
index = 0,
finish_reason = cjson.null,
logprobs = cjson.null,
},
},
id = kong
and kong.ctx
and kong.ctx.plugin
and kong.ctx.plugin.ai_proxy_cohere_stream_id,
model = model_info.name,
object = "text_completion",
}

end

elseif event.event_type == "stream-end" then
-- return a metadata object, with a null event
metadata = {
-- prompt_tokens = event.response.token_count.prompt_tokens,
-- completion_tokens = event.response.token_count.response_tokens,

completion_tokens = event.response
and event.response.meta
and event.response.meta.billed_units
and event.response.meta.billed_units.output_tokens
or
event.response
and event.response.token_count
and event.response.token_count.response_tokens
or 0,

prompt_tokens = event.response
and event.response.meta
and event.response.meta.billed_units
and event.response.meta.billed_units.input_tokens
or
event.response
and event.response.token_count
and event.response.token_count.prompt_tokens
or 0,
}

end

if new_event then
new_event = cjson.encode(new_event)
return new_event, nil, metadata
else
return nil, nil, metadata -- caller code will handle "unrecognised" event types
end
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model)
request_table.model = model.name
@@ -193,7 +306,7 @@ local transformers_from = {

if response_table.prompt and response_table.generations then
-- this is a "co.generate"

for i, v in ipairs(response_table.generations) do
prompt.choices[i] = {
index = (i-1),
@@ -243,6 +356,9 @@

return cjson.encode(prompt)
end,

["stream/llm/v1/chat"] = handle_stream_event,
["stream/llm/v1/completions"] = handle_stream_event,
}

function _M.from_format(response_string, model_info, route_type)
@@ -253,7 +369,7 @@ function _M.from_format(response_string, model_info, route_type)
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info)
local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
model_info.provider,
@@ -262,7 +378,7 @@
)
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
@@ -344,13 +460,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
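For reference, `handle_stream_event` above maps Cohere's native stream frames onto OpenAI-style chunks, and only the terminal `stream-end` frame yields metadata. An illustrative call (the function is local to this driver, and the frame and model values are made up):

local chunk, err, metadata = handle_stream_event(
  '{"event_type":"text-generation","text":"Hello"}',
  { name = "command" }, -- model_info stub
  "stream/llm/v1/chat"
)
-- chunk is an OpenAI-style JSON string along the lines of:
--   {"choices":[{"delta":{"content":"Hello"},"index":0,
--     "finish_reason":null,"logprobs":null}],
--    "id":"<generation_id captured from the stream-start frame>",
--    "model":"command","object":"chat.completion.chunk"}
-- metadata stays nil until the "stream-end" frame, which carries the
-- prompt/completion token counts instead of a content chunk.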
14 changes: 8 additions & 6 deletions kong/llm/drivers/llama2.lua
@@ -133,6 +133,8 @@ local transformers_from = {
["llm/v1/completions/raw"] = from_raw,
["llm/v1/chat/ollama"] = ai_shared.from_ollama,
["llm/v1/completions/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
}

local transformers_to = {
@@ -155,8 +157,8 @@ function _M.from_format(response_string, model_info, route_type)
if not transformers_from[transformer_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, transformer_type)
end
local ok, response_string, err = pcall(

local ok, response_string, err, metadata = pcall(
transformers_from[transformer_type],
response_string,
model_info,
@@ -166,7 +168,7 @@
return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error")
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
@@ -217,13 +219,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
@@ -265,7 +267,7 @@ function _M.configure_request(conf)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
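The `set_target` change above matters when `upstream_url` carries no explicit port: `socket.url` then leaves `port` as nil, which the old code passed straight to `kong.service.set_target`. A small sketch of the new behaviour (the URL is illustrative):

local socket_url = require("socket.url")

local parsed = socket_url.parse("https://llama2.example.com/raw/llm/v1/chat")
print(parsed.port)                  --> nil: no explicit port in the URL
print(tonumber(parsed.port) or 443) --> 443, the new fallback
-- note the fallback is always 443, regardless of the parsed scheme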
6 changes: 4 additions & 2 deletions kong/llm/drivers/mistral.lua
@@ -17,6 +17,8 @@ local DRIVER_NAME = "mistral"
local transformers_from = {
["llm/v1/chat/ollama"] = ai_shared.from_ollama,
["llm/v1/completions/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
}

local transformers_to = {
@@ -104,13 +106,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
25 changes: 21 additions & 4 deletions kong/llm/drivers/openai.lua
@@ -11,6 +11,18 @@ local socket_url = require "socket.url"
local DRIVER_NAME = "openai"
--

local function handle_stream_event(event_string)
if #event_string > 0 then
local lbl, val = event_string:match("(%w*): (.*)")

if lbl == "data" then
return val
end
end

return nil
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p)
-- if user passed a prompt as a chat, transform it to a chat message
@@ -29,8 +41,9 @@ local transformers_to = {
max_tokens = max_tokens,
temperature = temperature,
top_p = top_p,
stream = request_table.stream or false,
}

return this, "application/json", nil
end,

@@ -40,6 +53,7 @@
model = model,
max_tokens = max_tokens,
temperature = temperature,
stream = request_table.stream or false,
}

return this, "application/json", nil
@@ -52,7 +66,7 @@ local transformers_from = {
if err then
return nil, "'choices' not in llm/v1/chat response"
end

if response_object.choices then
return response_string, nil
else
@@ -72,6 +86,9 @@
return nil, "'choices' not in llm/v1/completions response"
end
end,

["stream/llm/v1/chat"] = handle_stream_event,
["stream/llm/v1/completions"] = handle_stream_event,
}

function _M.from_format(response_string, model_info, route_type)
@@ -155,13 +172,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
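`handle_stream_event` in this driver is a thin Server-Sent-Events field parser: it keeps the payload of `data:` lines and drops everything else. An illustrative input/output pair (frame contents are made up):

-- input, one SSE line from an OpenAI-compatible stream:
--   data: {"choices":[{"delta":{"content":"Hi"},"index":0}]}
-- handle_stream_event returns only the payload:
--   {"choices":[{"delta":{"content":"Hi"},"index":0}]}
-- empty keep-alive lines and non-"data" fields (e.g. "event: ...")
-- return nil, and the "data: [DONE]" sentinel passes through as
-- "[DONE]" for the caller to treat as end-of-stream.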

1 comment on commit cb1b163

@github-actions (Contributor):

Bazel Build

Docker image available: kong/kong:cb1b16313205d437bd9b2de09a36aabc9a10bd71
Artifacts available: https://github.com/Kong/kong/actions/runs/8655900661
