fix(ai-proxy): streaming when chunks contain truncated SSE messages

tysoekong committed Jun 10, 2024
1 parent a673ce2 commit 4e761a9
Showing 3 changed files with 168 additions and 2 deletions.
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-azure-streaming.yml
@@ -0,0 +1,5 @@
message: |
  **AI-proxy-plugin**: Fix a bug where certain Azure models would return partial tokens/words
  when in response-streaming mode.
scope: Plugin
type: bugfix
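
In plain terms: certain Azure model deployments flush streaming responses on arbitrary byte boundaries rather than on SSE event boundaries, so a single data: line can arrive split across two reads. A minimal sketch of the failure mode (payloads invented for illustration, not taken from this diff):

-- One SSE event split across two chunk reads (invented payloads).
local chunk_1 = 'data: { "choices": [ { "delta": { "content": "hel' -- head: no "\n\n" terminator yet
local chunk_2 = 'lo" } } ] }\n\n'                                   -- tail of the very same event

-- Neither piece JSON-decodes on its own; only the stitched-together
-- line is a complete, terminated event:
print(chunk_1 .. chunk_2)
--> data: { "choices": [ { "delta": { "content": "hello" } } ] }
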
20 changes: 18 additions & 2 deletions kong/llm/drivers/shared.lua
@@ -206,19 +206,35 @@ function _M.frame_to_events(frame)
        }
      end
    else
      -- standard SSE parser
      local event_lines = split(frame, "\n")
      local struct = { event = nil, id = nil, data = nil }

-     for _, dat in ipairs(event_lines) do
+     for i, dat in ipairs(event_lines) do
        if #dat < 1 then
          events[#events + 1] = struct
          struct = { event = nil, id = nil, data = nil }
        end

+       -- test for a truncated chunk on the last line (no trailing \r\n\r\n)
+       if #dat > 0 and #event_lines == i then
+         ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head")
+         kong.ctx.plugin.truncated_frame = dat
+         break  -- stop parsing immediately; the server has done something wrong
+       end
+
+       -- test for an abnormal start-of-frame (truncation tail)
+       if kong and kong.ctx.plugin.truncated_frame then
+         -- this is the tail of a previous incomplete chunk
+         ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail")
+         dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat)
+         kong.ctx.plugin.truncated_frame = nil
+       end
+
        local s1, _ = str_find(dat, ":")  -- find where the cut point is

        if s1 and s1 ~= 1 then
-         local field = str_sub(dat, 1, s1 - 1)  -- returns "data " from data: hello world
+         local field = str_sub(dat, 1, s1 - 1)  -- returns "data" from data: hello world
          local value = str_ltrim(str_sub(dat, s1 + 1))  -- returns "hello world" from data: hello world

          -- for now, not checking whether the value has already been set
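
Taken together, the two new branches re-assemble an event that spans two frames: the "head" branch stashes an unterminated final line in kong.ctx.plugin.truncated_frame, and the "tail" branch prepends it to the start of the next frame. Below is a simplified, self-contained sketch of that hand-off (not the plugin's real parser: ctx stands in for kong.ctx.plugin, the line splitter is reduced to basics, and SSE field parsing is omitted):

-- A minimal sketch of the head/tail hand-off, assuming a mocked context table.
local ctx = {}

local function feed(frame)
  local out = {}
  local lines = {}
  for line in (frame .. "\n"):gmatch("([^\n]*)\n") do
    lines[#lines + 1] = line
  end

  for i, dat in ipairs(lines) do
    if ctx.truncated_frame then
      dat = ctx.truncated_frame .. dat -- tail: re-attach the stored head
      ctx.truncated_frame = nil
    end

    if #dat > 0 and i == #lines then
      ctx.truncated_frame = dat -- head: the final line arrived without its terminator
      break
    end

    if #dat > 0 then
      out[#out + 1] = dat
    end
  end

  return out
end

local first = feed('data: {"a": 1}\n\ndata: {"b"') -- returns { 'data: {"a": 1}' }, stores 'data: {"b"'
local second = feed(': 2}\n\n')                     -- returns { 'data: {"b": 2}' }, state cleared

State survives between reads because kong.ctx.plugin is scoped to the in-flight request, so successive body-filter calls for the same response see the same table.
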
145 changes: 145 additions & 0 deletions spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua
@@ -90,6 +90,59 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
        }
      }
      location = "/openai/llm/v1/chat/partial" {
        content_by_lua_block {
          local _EVENT_CHUNKS = {
            [1] = 'data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}',
            [2] = 'data: { "choices": [ { "delta": { "content": "The " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "answer " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}',
            [3] = 'data: { "choices": [ { "delta": { "content": "to 1 + " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Ts',
            [4] = 'w1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}',
            [5] = 'data: { "choices": [ { "delta": { "content": "1 is " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "2." }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}',
            [6] = 'data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}',
            [7] = 'data: [DONE]',
          }
          local fmt = string.format
          local pl_file = require "pl.file"
          local json = require("cjson.safe")
          ngx.req.read_body()
          local body, err = ngx.req.get_body_data()
          body, err = json.decode(body)
          local token = ngx.req.get_headers()["authorization"]
          local token_query = ngx.req.get_uri_args()["apikey"]
          if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then
            ngx.req.read_body()
            local body, err = ngx.req.get_body_data()
            body, err = json.decode(body)
            if err or (body.messages == ngx.null) then
              ngx.status = 400
              ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json"))
            else
              -- GOOD RESPONSE
              ngx.status = 200
              ngx.header["Content-Type"] = "text/event-stream"
              for i, EVENT in ipairs(_EVENT_CHUNKS) do
                -- pretend to truncate chunks
                if _EVENT_CHUNKS[i+1] and _EVENT_CHUNKS[i+1]:sub(1, 5) ~= "data:" then
                  ngx.print(EVENT)
                else
                  ngx.print(fmt("%s\n\n", EVENT))
                end
              end
            end
          else
            ngx.status = 401
            ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json"))
          end
        }
      }
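
Note how the mock fakes the truncation: in the loop above, an entry is written without its "\n\n" terminator exactly when the next entry in _EVENT_CHUNKS is its continuation (it does not begin with "data:"). Entries 3 and 4 therefore leave the server as one SSE event split across two writes; roughly (a hypothetical reconstruction, with the long JSON bodies elided by "..."):

-- Rough shape of what the mock emits for entries 3 and 4 (JSON elided with "..."):
local head = 'data: { "choices": [ { "delta": { "content": "to 1 + " ... "id": "chatcmpl-9BXtBvU8Ts' -- written with no terminator
local tail = 'w1U7CarzV71vQEjvYwq", ... "system_fingerprint": null}\n\n'                             -- the next write completes it
assert(not head:find("\n\n", 1, true))       -- the head alone is not a finished event
assert((head .. tail):find("\n\n", 1, true)) -- only the joined halves form a terminated event

Whether the two writes actually reach the client as separate reads depends on proxy buffering, which is exactly the condition the new parser code has to tolerate.
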
      location = "/cohere/llm/v1/chat/good" {
        content_by_lua_block {
          local _EVENT_CHUNKS = {
@@ -291,6 +344,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
      }
      --

      -- 200 chat openai - PARTIAL SPLIT CHUNKS
      local openai_chat_partial = assert(bp.routes:insert {
        service = empty_service,
        protocols = { "http" },
        strip_path = true,
        paths = { "/openai/llm/v1/chat/partial" }
      })
      bp.plugins:insert {
        name = PLUGIN_NAME,
        route = { id = openai_chat_partial.id },
        config = {
          route_type = "llm/v1/chat",
          auth = {
            header_name = "Authorization",
            header_value = "Bearer openai-key",
          },
          model = {
            name = "gpt-3.5-turbo",
            provider = "openai",
            options = {
              max_tokens = 256,
              temperature = 1.0,
              upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/openai/llm/v1/chat/partial"
            },
          },
        },
      }
      --

      -- 200 chat cohere
      local cohere_chat_good = assert(bp.routes:insert {
        service = empty_service,
@@ -489,6 +571,69 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
        assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.")
      end)

      it("good stream request openai with partial split chunks", function()
        local httpc = http.new()

        local ok, err, _ = httpc:connect({
          scheme = "http",
          host = helpers.mock_upstream_host,
          port = helpers.get_proxy_port(),
        })
        if not ok then
          assert.is_nil(err)
        end

        -- Then send using `request`, supplying a path and `Host` header instead of a
        -- full URI.
        local res, err = httpc:request({
          path = "/openai/llm/v1/chat/partial",
          body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"),
          headers = {
            ["content-type"] = "application/json",
            ["accept"] = "application/json",
          },
        })
        if not res then
          assert.is_nil(err)
        end

        local reader = res.body_reader
        local buffer_size = 35536
        local events = {}
        local buf = require("string.buffer").new()

        -- extract events
        repeat
          -- receive the next chunk
          local buffer, err = reader(buffer_size)
          if err then
            assert.is_falsy(err and err ~= "closed")
          end

          if buffer then
            -- we need to rip each message from this chunk
            for s in buffer:gmatch("[^\r\n]+") do
              local s_copy = s
              s_copy = string.sub(s_copy, 7)  -- strip the leading "data: "
              s_copy = cjson.decode(s_copy)

              buf:put(s_copy
                  and s_copy.choices
                  and s_copy.choices[1]
                  and s_copy.choices[1].delta
                  and s_copy.choices[1].delta.content
                  or "")

              table.insert(events, s)
            end
          end
        until not buffer

        assert.equal(#events, 8)
        assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.")
      end)
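
The expected count of 8 follows from the mock table: entries 1, 6, and 7 each contribute one data: line, entries 2 and 5 contribute two apiece, and entries 3 and 4 re-assemble into a single event (1 + 2 + 1 + 2 + 1 + 1 = 8). The concatenated delta.content fields across those events spell out "The answer to 1 + 1 is 2.", so both assertions can only pass if the split frame was stitched back together correctly.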

      it("good stream request cohere", function()
        local httpc = http.new()
