From cb1b16313205d437bd9b2de09a36aabc9a10bd71 Mon Sep 17 00:00:00 2001 From: Jack Tysoe <91137069+tysoekong@users.noreply.github.com> Date: Fri, 12 Apr 2024 03:04:27 +0100 Subject: [PATCH] feat(ai-proxy): add streaming support and transformers (#12792) * feat(ai-proxy): add streaming support and transformers * feat(ai-proxy): streaming unit tests; hop-by-hop headers * fix cohere empty comments * fix(syntax): shared text extractor for ai token * fix(ai-proxy): integration tests for streaming * fix(ai-proxy): integration tests for streaming * Update 09-streaming_integration_spec.lua * Update kong/llm/init.lua Co-authored-by: Michael Martin * discussion_r1560031734 * discussion_r1560047662 * discussion_r1560109626 * discussion_r1560117584 * discussion_r1560120287 * discussion_r1560121506 * discussion_r1560123437 * discussion_r1561272376 * discussion_r1561272376 --------- Co-authored-by: Michael Martin --- .../kong/feat-ai-proxy-add-streaming.yml | 4 + kong/llm/drivers/anthropic.lua | 4 +- kong/llm/drivers/azure.lua | 5 +- kong/llm/drivers/cohere.lua | 126 ++++- kong/llm/drivers/llama2.lua | 14 +- kong/llm/drivers/mistral.lua | 6 +- kong/llm/drivers/openai.lua | 25 +- kong/llm/drivers/shared.lua | 219 +++++--- kong/llm/init.lua | 298 +++++++++- kong/plugins/ai-proxy/handler.lua | 66 ++- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 73 +++ .../09-streaming_integration_spec.lua | 514 ++++++++++++++++++ .../llm-v1-chat/requests/good-stream.json | 13 + .../ai-proxy/llama2/ollama/chat-stream.json | 13 + .../llm-v1-chat/requests/good-stream.json | 13 + .../openai/llm-v1-chat/requests/good.json | 3 +- .../llm-v1-chat/requests/good_own_model.json | 3 +- .../expected-requests/azure/llm-v1-chat.json | 1 + .../azure/llm-v1-completions.json | 3 +- .../mistral/openai/llm-v1-chat.json | 1 + .../expected-requests/openai/llm-v1-chat.json | 1 + .../openai/llm-v1-completions.json | 3 +- .../real-stream-frames/cohere/llm-v1-chat.txt | 1 + .../cohere/llm-v1-completions.txt | 1 + .../real-stream-frames/openai/llm-v1-chat.txt | 1 + .../openai/llm-v1-completions.txt | 1 + 26 files changed, 1292 insertions(+), 120 deletions(-) create mode 100644 changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml create mode 100644 spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua create mode 100644 spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good-stream.json create mode 100644 spec/fixtures/ai-proxy/llama2/ollama/chat-stream.json create mode 100644 spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json create mode 100644 spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-chat.txt create mode 100644 spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-completions.txt create mode 100644 spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt create mode 100644 spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt diff --git a/changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml b/changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml new file mode 100644 index 00000000000..4f4f348fefc --- /dev/null +++ b/changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml @@ -0,0 +1,4 @@ +message: | + **AI-Proxy**: add support for streaming event-by-event responses back to client on supported providers +scope: Plugin +type: feature diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 18c3f2bce5b..69b6ddf6838 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -272,13 +272,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 9207b20a54d..7918cf166bc 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -59,13 +59,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. @@ -82,7 +82,6 @@ end -- returns err or nil function _M.configure_request(conf) - local parsed_url if conf.model.options.upstream_url then diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 46bde9bc3e1..2788c749b46 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -12,6 +12,119 @@ local table_new = require("table.new") local DRIVER_NAME = "cohere" -- +local function handle_stream_event(event_string, model_info, route_type) + local metadata + + -- discard empty frames, it should either be a random new line, or comment + if #event_string < 1 then + return + end + + local event, err = cjson.decode(event_string) + if err then + return nil, "failed to decode event frame from cohere: " .. err, nil + end + + local new_event + + if event.event_type == "stream-start" then + kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id + + -- ignore the rest of this one + new_event = { + choices = { + [1] = { + delta = { + content = "", + role = "assistant", + }, + index = 0, + }, + }, + id = event.generation_id, + model = model_info.name, + object = "chat.completion.chunk", + } + + elseif event.event_type == "text-generation" then + -- this is a token + if route_type == "stream/llm/v1/chat" then + new_event = { + choices = { + [1] = { + delta = { + content = event.text or "", + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = kong + and kong.ctx + and kong.ctx.plugin + and kong.ctx.plugin.ai_proxy_cohere_stream_id, + model = model_info.name, + object = "chat.completion.chunk", + } + + elseif route_type == "stream/llm/v1/completions" then + new_event = { + choices = { + [1] = { + text = event.text or "", + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = kong + and kong.ctx + and kong.ctx.plugin + and kong.ctx.plugin.ai_proxy_cohere_stream_id, + model = model_info.name, + object = "text_completion", + } + + end + + elseif event.event_type == "stream-end" then + -- return a metadata object, with a null event + metadata = { + -- prompt_tokens = event.response.token_count.prompt_tokens, + -- completion_tokens = event.response.token_count.response_tokens, + + completion_tokens = event.response + and event.response.meta + and event.response.meta.billed_units + and event.response.meta.billed_units.output_tokens + or + event.response + and event.response.token_count + and event.response.token_count.response_tokens + or 0, + + prompt_tokens = event.response + and event.response.meta + and event.response.meta.billed_units + and event.response.meta.billed_units.input_tokens + or + event.response + and event.response.token_count + and event.token_count.prompt_tokens + or 0, + } + + end + + if new_event then + new_event = cjson.encode(new_event) + return new_event, nil, metadata + else + return nil, nil, metadata -- caller code will handle "unrecognised" event types + end +end + local transformers_to = { ["llm/v1/chat"] = function(request_table, model) request_table.model = model.name @@ -193,7 +306,7 @@ local transformers_from = { if response_table.prompt and response_table.generations then -- this is a "co.generate" - + for i, v in ipairs(response_table.generations) do prompt.choices[i] = { index = (i-1), @@ -243,6 +356,9 @@ local transformers_from = { return cjson.encode(prompt) end, + + ["stream/llm/v1/chat"] = handle_stream_event, + ["stream/llm/v1/completions"] = handle_stream_event, } function _M.from_format(response_string, model_info, route_type) @@ -253,7 +369,7 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info) + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, @@ -262,7 +378,7 @@ function _M.from_format(response_string, model_info, route_type) ) end - return response_string, nil + return response_string, nil, metadata end function _M.to_format(request_table, model_info, route_type) @@ -344,13 +460,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 7e965e2c553..bf3ee42ee74 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -133,6 +133,8 @@ local transformers_from = { ["llm/v1/completions/raw"] = from_raw, ["llm/v1/chat/ollama"] = ai_shared.from_ollama, ["llm/v1/completions/ollama"] = ai_shared.from_ollama, + ["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama, + ["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama, } local transformers_to = { @@ -155,8 +157,8 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[transformer_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, transformer_type) end - - local ok, response_string, err = pcall( + + local ok, response_string, err, metadata = pcall( transformers_from[transformer_type], response_string, model_info, @@ -166,7 +168,7 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error") end - return response_string, nil + return response_string, nil, metadata end function _M.to_format(request_table, model_info, route_type) @@ -217,13 +219,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. @@ -265,7 +267,7 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index 84f96178295..d091939eeb2 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -17,6 +17,8 @@ local DRIVER_NAME = "mistral" local transformers_from = { ["llm/v1/chat/ollama"] = ai_shared.from_ollama, ["llm/v1/completions/ollama"] = ai_shared.from_ollama, + ["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama, + ["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama, } local transformers_to = { @@ -104,13 +106,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 5d612055236..27472be5c9a 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -11,6 +11,18 @@ local socket_url = require "socket.url" local DRIVER_NAME = "openai" -- +local function handle_stream_event(event_string) + if #event_string > 0 then + local lbl, val = event_string:match("(%w*): (.*)") + + if lbl == "data" then + return val + end + end + + return nil +end + local transformers_to = { ["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p) -- if user passed a prompt as a chat, transform it to a chat message @@ -29,8 +41,9 @@ local transformers_to = { max_tokens = max_tokens, temperature = temperature, top_p = top_p, + stream = request_table.stream or false, } - + return this, "application/json", nil end, @@ -40,6 +53,7 @@ local transformers_to = { model = model, max_tokens = max_tokens, temperature = temperature, + stream = request_table.stream or false, } return this, "application/json", nil @@ -52,7 +66,7 @@ local transformers_from = { if err then return nil, "'choices' not in llm/v1/chat response" end - + if response_object.choices then return response_string, nil else @@ -72,6 +86,9 @@ local transformers_from = { return nil, "'choices' not in llm/v1/completions response" end end, + + ["stream/llm/v1/chat"] = handle_stream_event, + ["stream/llm/v1/completions"] = handle_stream_event, } function _M.from_format(response_string, model_info, route_type) @@ -155,13 +172,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table) headers[conf.auth.header_name] = conf.auth.header_value end - local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts) + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) if err then return nil, nil, "request to ai service failed: " .. err end if return_res_table then - return res, res.status, nil + return res, res.status, nil, httpc else -- At this point, the entire request / response is complete and the connection -- will be closed or back on the connection pool. diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index a254cad92cf..b38fd8d6c84 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -1,10 +1,11 @@ local _M = {} -- imports -local cjson = require("cjson.safe") -local http = require("resty.http") -local fmt = string.format -local os = os +local cjson = require("cjson.safe") +local http = require("resty.http") +local fmt = string.format +local os = os +local parse_url = require("socket.url").parse -- local log_entry_keys = { @@ -21,6 +22,11 @@ local log_entry_keys = { local openai_override = os.getenv("OPENAI_TEST_PORT") +_M.streaming_has_token_counts = { + ["cohere"] = true, + ["llama2"] = true, +} + _M.upstream_url_format = { openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"), anthropic = "https://api.anthropic.com:443", @@ -86,6 +92,48 @@ _M.clear_response_headers = { }, } + +local function handle_stream_event(event_table, model_info, route_type) + if event_table.done then + -- return analytics table + return nil, nil, { + prompt_tokens = event_table.prompt_eval_count or 0, + completion_tokens = event_table.eval_count or 0, + } + + else + -- parse standard response frame + if route_type == "stream/llm/v1/chat" then + return { + choices = { + [1] = { + delta = { + content = event_table.message and event_table.message.content or "", + }, + index = 0, + }, + }, + model = event_table.model, + object = "chat.completion.chunk", + } + + elseif route_type == "stream/llm/v1/completions" then + return { + choices = { + [1] = { + text = event_table.response or "", + index = 0, + }, + }, + model = event_table.model, + object = "text_completion", + } + + end + end +end + + function _M.to_ollama(request_table, model) local input = {} @@ -117,57 +165,67 @@ function _M.to_ollama(request_table, model) end function _M.from_ollama(response_string, model_info, route_type) + local output, _, analytics + local response_table, err = cjson.decode(response_string) if err then return nil, "failed to decode ollama response" end - -- there is no direct field indicating STOP reason, so calculate it manually - local stop_length = (model_info.options and model_info.options.max_tokens) or -1 - local stop_reason = "stop" - if response_table.eval_count and response_table.eval_count == stop_length then - stop_reason = "length" - end + if route_type == "stream/llm/v1/chat" then + output, _, analytics = handle_stream_event(response_table, model_info, route_type) - local output = {} - - -- common fields - output.model = response_table.model - output.created = response_table.created_at - - -- analytics - output.usage = { - completion_tokens = response_table.eval_count or 0, - prompt_tokens = response_table.prompt_eval_count or 0, - total_tokens = (response_table.eval_count or 0) + - (response_table.prompt_eval_count or 0), - } - - if route_type == "llm/v1/chat" then - output.object = "chat.completion" - output.choices = { - [1] = { - finish_reason = stop_reason, - index = 0, - message = response_table.message, - } + elseif route_type == "stream/llm/v1/completions" then + output, _, analytics = handle_stream_event(response_table, model_info, route_type) + + else + -- there is no direct field indicating STOP reason, so calculate it manually + local stop_length = (model_info.options and model_info.options.max_tokens) or -1 + local stop_reason = "stop" + if response_table.eval_count and response_table.eval_count == stop_length then + stop_reason = "length" + end + + output = {} + + -- common fields + output.model = response_table.model + output.created = response_table.created_at + + -- analytics + output.usage = { + completion_tokens = response_table.eval_count or 0, + prompt_tokens = response_table.prompt_eval_count or 0, + total_tokens = (response_table.eval_count or 0) + + (response_table.prompt_eval_count or 0), } - elseif route_type == "llm/v1/completions" then - output.object = "text_completion" - output.choices = { - [1] = { - index = 0, - text = response_table.response, + if route_type == "llm/v1/chat" then + output.object = "chat.completion" + output.choices = { + { + finish_reason = stop_reason, + index = 0, + message = response_table.message, + } } - } - else - return nil, "no ollama-format transformer for response type " .. route_type + elseif route_type == "llm/v1/completions" then + output.object = "text_completion" + output.choices = { + { + index = 0, + text = response_table.response, + } + } + + else + return nil, "no ollama-format transformer for response type " .. route_type + end end - return cjson.encode(output) + return output and cjson.encode(output) or nil, nil, analytics end function _M.pre_request(conf, request_table) @@ -188,9 +246,24 @@ function _M.pre_request(conf, request_table) return true, nil end -function _M.post_request(conf, response_string) - if conf.logging and conf.logging.log_payloads then - kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_string) +function _M.post_request(conf, response_object) + local err + + if type(response_object) == "string" then + -- set raw string body first, then decode + if conf.logging and conf.logging.log_payloads then + kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object) + end + + response_object, err = cjson.decode(response_object) + if err then + return nil, "failed to decode response from JSON" + end + else + -- this has come from another AI subsystem, and contains "response" field + if conf.logging and conf.logging.log_payloads then + kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object.response or "ERROR__NOT_SET") + end end -- analytics and logging @@ -207,11 +280,6 @@ function _M.post_request(conf, response_string) } end - local response_object, err = cjson.decode(response_string) - if err then - return nil, "failed to decode response from JSON" - end - -- this captures the openai-format usage stats from the transformed response body if response_object.usage then if response_object.usage.prompt_tokens then @@ -239,7 +307,7 @@ function _M.post_request(conf, response_string) return nil end -function _M.http_request(url, body, method, headers, http_opts) +function _M.http_request(url, body, method, headers, http_opts, buffered) local httpc = http.new() if http_opts.http_timeout then @@ -250,19 +318,48 @@ function _M.http_request(url, body, method, headers, http_opts) httpc:set_proxy_options(http_opts.proxy_opts) end - local res, err = httpc:request_uri( - url, - { - method = method, - body = body, - headers = headers, + local parsed = parse_url(url) + + if buffered then + local ok, err, _ = httpc:connect({ + scheme = parsed.scheme, + host = parsed.host, + port = parsed.port or 443, -- this always fails. experience. + ssl_server_name = parsed.host, ssl_verify = http_opts.https_verify, }) - if not res then - return nil, "request failed: " .. err - end + if not ok then + return nil, err + end - return res, nil + local res, err = httpc:request({ + path = parsed.path or "/", + query = parsed.query, + method = method, + headers = headers, + body = body, + }) + if not res then + return nil, "connection failed: " .. err + end + + return res, nil, httpc + else + -- 'single-shot' + local res, err = httpc:request_uri( + url, + { + method = method, + body = body, + headers = headers, + ssl_verify = http_opts.https_verify, + }) + if not res then + return nil, "request failed: " .. err + end + + return res, nil, nil + end end return _M diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 6b973ae262e..f18e1a5ad8b 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -1,10 +1,13 @@ -- imports -local typedefs = require("kong.db.schema.typedefs") -local fmt = string.format -local cjson = require("cjson.safe") -local re_match = ngx.re.match - +local typedefs = require("kong.db.schema.typedefs") +local fmt = string.format +local cjson = require("cjson.safe") +local re_match = ngx.re.match +local buf = require("string.buffer") +local lower = string.lower +local meta = require "kong.meta" local ai_shared = require("kong.llm.drivers.shared") +local strip = require("kong.tools.utils").strip -- local _M = {} @@ -48,6 +51,12 @@ local model_options_schema = { type = "record", required = false, fields = { + { response_streaming = { + type = "string", + description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via WebSocket.", + required = true, + default = "allow", + one_of = { "allow", "deny", "always" } }}, { max_tokens = { type = "integer", description = "Defines the max_tokens, if using chat or completion models.", @@ -225,6 +234,15 @@ _M.config_schema = { }, } +local streaming_skip_headers = { + ["connection"] = true, + ["content-type"] = true, + ["keep-alive"] = true, + ["set-cookie"] = true, + ["transfer-encoding"] = true, + ["via"] = true, +} + local formats_compatible = { ["llm/v1/chat"] = { ["llm/v1/chat"] = true, @@ -234,6 +252,20 @@ local formats_compatible = { }, } +local function bad_request(msg) + ngx.log(ngx.WARN, msg) + ngx.status = 400 + ngx.header["Content-Type"] = "application/json" + ngx.say(cjson.encode({ error = { message = msg } })) +end + +local function internal_server_error(msg) + ngx.log(ngx.ERR, msg) + ngx.status = 500 + ngx.header["Content-Type"] = "application/json" + ngx.say(cjson.encode({ error = { message = msg } })) +end + local function identify_request(request) -- primitive request format determination local formats = {} @@ -260,6 +292,103 @@ local function identify_request(request) end end +local function get_token_text(event_t) + -- chat + return + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].delta and + event_t.choices[1].delta.content + + or + + -- completions + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].text + + or "" +end + +-- Function to count the number of words in a string +local function count_words(str) + local count = 0 + for word in str:gmatch("%S+") do + count = count + 1 + end + return count +end + +-- Function to count the number of words or tokens based on the content type +local function count_prompt(content, tokens_factor) + local count = 0 + + if type(content) == "string" then + count = count_words(content) * tokens_factor + elseif type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + count = count + (count_words(item) * tokens_factor) + elseif type(item) == "number" then + count = count + 1 + elseif type(item) == "table" then + for _2, item2 in ipairs(item) do + if type(item2) == "number" then + count = count + 1 + else + return nil, "Invalid request format" + end + end + else + return nil, "Invalid request format" + end + end + else + return nil, "Invalid request format" + end + return count +end + +function _M:calculate_cost(query_body, tokens_models, tokens_factor) + local query_cost = 0 + local err + + -- Check if max_tokens is provided in the request body + local max_tokens = query_body.max_tokens + + if not max_tokens then + if query_body.model and tokens_models then + max_tokens = tonumber(tokens_models[query_body.model]) + end + end + + if not max_tokens then + return nil, "No max_tokens in query and no key found in the plugin config for model: " .. query_body.model + end + + if query_body.messages then + -- Calculate the cost based on the content type + for _, message in ipairs(query_body.messages) do + query_cost = query_cost + (count_words(message.content) * tokens_factor) + end + elseif query_body.prompt then + -- Calculate the cost based on the content type + query_cost, err = count_prompt(query_body.prompt, tokens_factor) + if err then + return nil, err + end + else + return nil, "No messages or prompt in query" + end + + -- Round the total cost quantified + query_cost = math.floor(query_cost + 0.5) + + return query_cost +end + function _M.is_compatible(request, route_type) local format, err = identify_request(request) if err then @@ -273,6 +402,160 @@ function _M.is_compatible(request, route_type) return false, fmt("[%s] message format is not compatible with [%s] route type", format, route_type) end +function _M:handle_streaming_request(body) + -- convert it to the specified driver format + local request, _, err = self.driver.to_format(body, self.conf.model, self.conf.route_type) + if err then + return internal_server_error(err) + end + + -- run the shared logging/analytics/auth function + ai_shared.pre_request(self.conf, request) + + local prompt_tokens = 0 + local err + if not ai_shared.streaming_has_token_counts[self.conf.model.provider] then + -- Estimate the cost using KONG CX's counter implementation + prompt_tokens, err = self:calculate_cost(request, {}, 1.8) + if err then + return internal_server_error("unable to estimate request token cost: " .. err) + end + end + + -- send it to the ai service + local res, _, err, httpc = self.driver.subrequest(request, self.conf, self.http_opts, true) + if err then + return internal_server_error("failed to connect to " .. self.conf.model.provider .. " for streaming: " .. err) + end + if res.status ~= 200 then + err = "bad status code whilst opening streaming to " .. self.conf.model.provider .. ": " .. res.status + ngx.log(ngx.WARN, err) + return bad_request(err) + end + + -- get a big enough buffered ready to make sure we rip the entire chunk(s) each time + local reader = res.body_reader + local buffer_size = 35536 + local events + + -- we create a fake "kong response" table to pass to the telemetry handler later + local telemetry_fake_table = { + response = buf:new(), + usage = { + prompt_tokens = prompt_tokens, + completion_tokens = 0, + total_tokens = 0, + }, + } + + ngx.status = 200 + ngx.header["Content-Type"] = "text/event-stream" + ngx.header["Via"] = meta._SERVER_TOKENS + + for k, v in pairs(res.headers) do + if not streaming_skip_headers[lower(k)] then + ngx.header[k] = v + end + end + + -- server-sent events should ALWAYS be chunk encoded. + -- if they aren't then... we just won't support them. + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + ngx.log(ngx.ERR, "failed to read chunk of streaming buffer, ", err) + break + elseif not buffer then + break + end + + -- we need to rip each message from this chunk + events = {} + for s in buffer:gmatch("[^\r\n]+") do + table.insert(events, s) + end + + local metadata + local route_type = "stream/" .. self.conf.route_type + + -- then parse each into the standard inference format + for i, event in ipairs(events) do + local event_t + local token_t + + -- some LLMs do a final reply with token counts, and such + -- so we will stash them if supported + local formatted, err, this_metadata = self.driver.from_format(event, self.conf.model, route_type) + if err then + return internal_server_error(err) + end + + metadata = this_metadata or metadata + + -- handle event telemetry + if self.conf.logging.log_statistics then + + if not ai_shared.streaming_has_token_counts[self.conf.model.provider] then + event_t = cjson.decode(formatted) + token_t = get_token_text(event_t) + + -- incredibly loose estimate based on https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them + -- but this is all we can do until OpenAI fixes this... + -- + -- essentially, every 4 characters is a token, with minimum of 1 per event + telemetry_fake_table.usage.completion_tokens = + telemetry_fake_table.usage.completion_tokens + math.ceil(#strip(token_t) / 4) + + elseif metadata then + telemetry_fake_table.usage.completion_tokens = metadata.completion_tokens + telemetry_fake_table.usage.prompt_tokens = metadata.prompt_tokens + end + + end + + -- then stream to the client + if formatted then -- only stream relevant frames back to the user + if self.conf.logging.log_payloads then + -- append the "choice" to the buffer, for logging later. this actually works! + if not event_t then + event_t, err = cjson.decode(formatted) + end + + if err then + return internal_server_error("something wrong with decoding a specific token") + end + + if not token_t then + token_t = get_token_text(event_t) + end + + telemetry_fake_table.response:put(token_t) + end + + -- construct, transmit, and flush the frame + ngx.print("data: ", formatted, "\n\n") + ngx.flush(true) + end + end + + until not buffer + + local ok, err = httpc:set_keepalive() + if not ok then + -- continue even if keepalive gets killed + ngx.log(ngx.WARN, "setting keepalive failed: ", err) + end + + -- process telemetry + telemetry_fake_table.response = telemetry_fake_table.response:tostring() + + telemetry_fake_table.usage.total_tokens = telemetry_fake_table.usage.completion_tokens + + telemetry_fake_table.usage.prompt_tokens + + ai_shared.post_request(self.conf, telemetry_fake_table) +end + function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex_match) local err, _ @@ -287,7 +570,8 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex role = "user", content = request, } - } + }, + stream = false, } -- convert it to the specified driver format @@ -325,7 +609,7 @@ function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex and ai_response.choices[1].message and ai_response.choices[1].message.content if not new_request_body then - return nil, "no response choices received from upstream AI service" + return nil, "no 'choices' in upstream AI service response" end -- if specified, extract the first regex match from the AI response diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 8b7564b480c..64ea7c17dd2 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -2,6 +2,7 @@ local _M = {} -- imports local ai_shared = require("kong.llm.drivers.shared") +local ai_module = require("kong.llm") local llm = require("kong.llm") local cjson = require("cjson.safe") local kong_utils = require("kong.tools.gzip") @@ -112,14 +113,10 @@ end function _M:access(conf) - kong.service.request.enable_buffering() - -- store the route_type in ctx for use in response parsing local route_type = conf.route_type kong.ctx.plugin.operation = route_type - local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - local request_table -- we may have received a replacement / decorated request body from another AI plugin @@ -145,33 +142,50 @@ function _M:access(conf) return bad_request(err) end - -- execute pre-request hooks for this driver - local ok, err = ai_driver.pre_request(conf, request_table) - if not ok then - return bad_request(err) - end + if request_table.stream or conf.model.options.response_streaming == "always" then + kong.ctx.shared.skip_response_transformer = true - -- transform the body to Kong-format for this provider/model - local parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type) - if err then - return bad_request(err) - end + -- into sub-request streaming handler + -- everything happens in the access phase here + if conf.model.options.response_streaming == "deny" then + return bad_request("response streaming is not enabled for this LLM") + end - -- execute pre-request hooks for "all" drivers before set new body - local ok, err = ai_shared.pre_request(conf, parsed_request_body) - if not ok then - return bad_request(err) - end + local llm_handler = ai_module:new(conf, {}) + llm_handler:handle_streaming_request(request_table) + else + kong.service.request.enable_buffering() - kong.service.request.set_body(parsed_request_body, content_type) + local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf) - if not ok then - return internal_server_error(err) - end + -- execute pre-request hooks for this driver + local ok, err = ai_driver.pre_request(conf, request_table) + if not ok then + return bad_request(err) + end - -- lights out, and away we go + -- transform the body to Kong-format for this provider/model + local parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type) + if err then + return bad_request(err) + end + + -- execute pre-request hooks for "all" drivers before set new body + local ok, err = ai_shared.pre_request(conf, parsed_request_body) + if not ok then + return bad_request(err) + end + + kong.service.request.set_body(parsed_request_body, content_type) + + -- now re-configure the request for this operation type + local ok, err = ai_driver.configure_request(conf) + if not ok then + return internal_server_error(err) + end + + -- lights out, and away we go + end end diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 4bc5eb76d76..7773ee6c71d 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -175,6 +175,49 @@ local FORMATS = { }, } +local STREAMS = { + openai = { + ["llm/v1/chat"] = { + name = "gpt-4", + provider = "openai", + }, + ["llm/v1/completions"] = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + }, + }, + cohere = { + ["llm/v1/chat"] = { + name = "command", + provider = "cohere", + }, + ["llm/v1/completions"] = { + name = "command-light", + provider = "cohere", + }, + }, +} + +local expected_stream_choices = { + ["llm/v1/chat"] = { + [1] = { + delta = { + content = "the answer", + }, + finish_reason = ngx.null, + index = 0, + logprobs = ngx.null, + }, + }, + ["llm/v1/completions"] = { + [1] = { + text = "the answer", + finish_reason = ngx.null, + index = 0, + logprobs = ngx.null, + }, + }, +} describe(PLUGIN_NAME .. ": (unit)", function() @@ -310,6 +353,35 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) end + -- streaming tests + for provider_name, provider_format in pairs(STREAMS) do + + describe(provider_name .. " stream format tests", function() + + for format_name, config in pairs(provider_format) do + + ---- actual testing code begins here + describe(format_name .. " format test", function() + local driver = require("kong.llm.drivers." .. config.provider) + + -- what we do is first put the SAME request message from the user, through the converter, for this provider/format + it("converts to provider request format correctly", function() + local real_stream_frame = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/real-stream-frames/%s/%s.txt", config.provider, pl_replace(format_name, "/", "-"))) + local real_transformed_frame, err = driver.from_format(real_stream_frame, config, "stream/" .. format_name) + + assert.is_nil(err) + + real_transformed_frame = cjson.decode(real_transformed_frame) + assert.same(expected_stream_choices[format_name], real_transformed_frame.choices) + end) + + end) + end + end) + + end + + -- generic tests it("throws correct error when format is not supported", function() local driver = require("kong.llm.drivers.mistral") -- one-shot, random example of provider with only prompt support @@ -334,4 +406,5 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.is_nil(content_type) assert.equal(err, "no transformer available to format mistral://llm/v1/chatnopenotsupported/ollama") end) + end) diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua new file mode 100644 index 00000000000..089dbbb671b --- /dev/null +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -0,0 +1,514 @@ +local helpers = require "spec.helpers" +local cjson = require "cjson.safe" +local pl_file = require "pl.file" + +local http = require("resty.http") + +local PLUGIN_NAME = "ai-proxy" +local MOCK_PORT = helpers.get_available_port() + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() + local client + + lazy_setup(function() + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + + -- set up openai mock fixtures + local fixtures = { + http_mock = {}, + dns_mock = helpers.dns_mock.new({ + mocks_only = true, -- don't fallback to "real" DNS + }), + } + + fixtures.dns_mock:A { + name = "api.openai.com", + address = "127.0.0.1", + } + + fixtures.dns_mock:A { + name = "api.cohere.com", + address = "127.0.0.1", + } + + fixtures.http_mock.streams = [[ + server { + server_name openai; + listen ]]..MOCK_PORT..[[; + + default_type 'application/json'; + chunked_transfer_encoding on; + proxy_buffering on; + proxy_buffer_size 600; + proxy_buffers 10 600; + + location = "/openai/llm/v1/chat/good" { + content_by_lua_block { + local _EVENT_CHUNKS = { + [1] = 'data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [2] = 'data: { "choices": [ { "delta": { "content": "The " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "answer " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [3] = 'data: { "choices": [ { "delta": { "content": "to 1 + " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [4] = 'data: { "choices": [ { "delta": { "content": "1 is " }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}\n\ndata: { "choices": [ { "delta": { "content": "2." }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [5] = 'data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', + [6] = 'data: [DONE]', + } + + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + else + -- GOOD RESPONSE + + ngx.status = 200 + ngx.header["Content-Type"] = "text/event-stream" + ngx.header["Transfer-Encoding"] = "chunked" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + ngx.print(fmt("%s\n\n", EVENT)) + end + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/cohere/llm/v1/chat/good" { + content_by_lua_block { + local _EVENT_CHUNKS = { + [1] = '{"is_finished":false,"event_type":"stream-start","generation_id":"3f41d0ea-0d9c-4ecd-990a-88ba46ede663"}', + [2] = '{"is_finished":false,"event_type":"text-generation","text":"1"}', + [3] = '{"is_finished":false,"event_type":"text-generation","text":" +"}', + [4] = '{"is_finished":false,"event_type":"text-generation","text":" 1"}', + [5] = '{"is_finished":false,"event_type":"text-generation","text":" ="}', + [6] = '{"is_finished":false,"event_type":"text-generation","text":" 2"}', + [7] = '{"is_finished":false,"event_type":"text-generation","text":"."}\n\n{"is_finished":false,"event_type":"text-generation","text":" This"}', + [8] = '{"is_finished":false,"event_type":"text-generation","text":" is"}', + [9] = '{"is_finished":false,"event_type":"text-generation","text":" the"}', + [10] = '{"is_finished":false,"event_type":"text-generation","text":" most"}\n\n{"is_finished":false,"event_type":"text-generation","text":" basic"}', + [11] = '{"is_finished":false,"event_type":"text-generation","text":" example"}\n\n{"is_finished":false,"event_type":"text-generation","text":" of"}\n\n{"is_finished":false,"event_type":"text-generation","text":" addition"}', + [12] = '{"is_finished":false,"event_type":"text-generation","text":"."}', + [13] = '{"is_finished":true,"event_type":"stream-end","response":{"response_id":"4658c450-4755-4454-8f9e-a98dd376b9ad","text":"1 + 1 = 2. This is the most basic example of addition.","generation_id":"3f41d0ea-0d9c-4ecd-990a-88ba46ede663","chat_history":[{"role":"USER","message":"What is 1 + 1?"},{"role":"CHATBOT","message":"1 + 1 = 2. This is the most basic example of addition, an arithmetic operation that involves combining two or more numbers together to find their sum. In this case, the numbers being added are both 1, and the answer is 2, meaning 1 + 1 = 2 is an algebraic equation that shows the relationship between these two numbers when added together. This equation is often used as an example of the importance of paying close attention to details when doing math problems, because it is surprising to some people that something so trivial as adding 1 + 1 could ever equal anything other than 2."}],"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":57,"output_tokens":123},"tokens":{"input_tokens":68,"output_tokens":123}}},"finish_reason":"COMPLETE"}', + } + + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "Bearer cohere-key" or token_query == "cohere-key" or body.apikey == "cohere-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + else + -- GOOD RESPONSE + + ngx.status = 200 + ngx.header["Content-Type"] = "text/event-stream" + ngx.header["Transfer-Encoding"] = "chunked" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + ngx.print(fmt("%s\n\n", EVENT)) + end + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/openai/llm/v1/chat/bad" { + content_by_lua_block { + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + else + -- BAD RESPONSE + + ngx.status = 400 + + ngx.say('{"error": { "message": "failure of some kind" }}') + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } + } + ]] + + local empty_service = assert(bp.services:insert { + name = "empty_service", + host = "localhost", + port = 8080, + path = "/", + }) + + -- 200 chat openai + local openai_chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = openai_chat_good.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/openai/llm/v1/chat/good" + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = openai_chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + + -- 200 chat cohere + local cohere_chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/cohere/llm/v1/chat/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = cohere_chat_good.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer cohere-key", + }, + model = { + name = "command", + provider = "cohere", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/cohere/llm/v1/chat/good" + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = cohere_chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + + -- 400 chat openai + local openai_chat_bad = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/bad" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = openai_chat_bad.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/openai/llm/v1/chat/bad" + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = openai_chat_bad.id }, + config = { + path = "/dev/stdout", + }, + } + -- + + -- start kong + assert(helpers.start_kong({ + -- set the strategy + database = strategy, + -- use the custom test template to create a local mock server + nginx_conf = "spec/fixtures/custom_nginx.template", + -- make sure our plugin gets loaded + plugins = "bundled," .. PLUGIN_NAME, + -- write & load declarative config, only if 'strategy=off' + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + }, nil, nil, fixtures)) + end) + + lazy_teardown(function() + helpers.stop_kong() + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + describe("stream llm/v1/chat", function() + it("good stream request openai", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + ngx.log(ngx.ERR, "connection failed: ", err) + return + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/openai/llm/v1/chat/good", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + ngx.log(ngx.ERR, "request failed: ", err) + return + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + ngx.log(ngx.ERR, err) + break + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 8) + assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.") + end) + + it("good stream request cohere", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + ngx.log(ngx.ERR, "connection failed: ", err) + return + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/cohere/llm/v1/chat/good", + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + ngx.log(ngx.ERR, "request failed: ", err) + return + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + ngx.log(ngx.ERR, err) + break + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 16) + assert.equal(buf:tostring(), "1 + 1 = 2. This is the most basic example of addition.") + end) + + it("bad request is returned to the client not-streamed", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + ngx.log(ngx.ERR, "connection failed: ", err) + return + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/openai/llm/v1/chat/bad", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + ngx.log(ngx.ERR, "request failed: ", err) + return + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + ngx.log(ngx.ERR, err) + break + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 1) + assert.equal(res.status, 400) + end) + + end) + end) + +end end diff --git a/spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good-stream.json b/spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good-stream.json new file mode 100644 index 00000000000..c05edd15b8a --- /dev/null +++ b/spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good-stream.json @@ -0,0 +1,13 @@ +{ + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 1 + 1?" + } + ], + "stream": true +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/llama2/ollama/chat-stream.json b/spec/fixtures/ai-proxy/llama2/ollama/chat-stream.json new file mode 100644 index 00000000000..790bb70726a --- /dev/null +++ b/spec/fixtures/ai-proxy/llama2/ollama/chat-stream.json @@ -0,0 +1,13 @@ +{ + "messages":[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 1 + 1?" + } + ], + "stream": true +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json new file mode 100644 index 00000000000..790bb70726a --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json @@ -0,0 +1,13 @@ +{ + "messages":[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 1 + 1?" + } + ], + "stream": true +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json index 4cc10281331..5d32fa0af9e 100644 --- a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json +++ b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json @@ -8,5 +8,6 @@ "role": "user", "content": "What is 1 + 1?" } - ] + ], + "stream": false } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json index 9c27eaa1186..e72de3c7a70 100644 --- a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json +++ b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json @@ -9,5 +9,6 @@ "content": "What is 1 + 1?" } ], - "model": "try-to-override-the-model" + "model": "try-to-override-the-model", + "stream": false } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-chat.json index c02be6e513f..5a9c2d9e70c 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-chat.json @@ -28,5 +28,6 @@ "model": "gpt-4", "max_tokens": 512, "temperature": 0.5, + "stream": false, "top_p": 1.0 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json index 0a9efde384f..bc7368bb7d4 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json @@ -2,5 +2,6 @@ "prompt": "Explain why you can't divide by zero?", "model": "gpt-3.5-turbo-instruct", "max_tokens": 512, - "temperature": 0.5 + "temperature": 0.5, + "stream": false } diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/mistral/openai/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/mistral/openai/llm-v1-chat.json index 4e5191c0963..adb87db4085 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/mistral/openai/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/mistral/openai/llm-v1-chat.json @@ -27,5 +27,6 @@ ], "model": "mistral-tiny", "max_tokens": 512, + "stream": false, "temperature": 0.5 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-chat.json index 23e165166a2..d7aa2028f45 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-chat.json @@ -27,5 +27,6 @@ ], "model": "gpt-4", "max_tokens": 512, + "stream": false, "temperature": 0.5 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-completions.json index 0a9efde384f..bc7368bb7d4 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/openai/llm-v1-completions.json @@ -2,5 +2,6 @@ "prompt": "Explain why you can't divide by zero?", "model": "gpt-3.5-turbo-instruct", "max_tokens": 512, - "temperature": 0.5 + "temperature": 0.5, + "stream": false } diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-chat.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-chat.txt new file mode 100644 index 00000000000..f61584882d6 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-chat.txt @@ -0,0 +1 @@ +{"is_finished":false,"event_type":"text-generation","text":"the answer"} diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-completions.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-completions.txt new file mode 100644 index 00000000000..4add796f4f2 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/cohere/llm-v1-completions.txt @@ -0,0 +1 @@ +{"text":"the answer","is_finished":false,"event_type":"text-generation"} diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt new file mode 100644 index 00000000000..2f7c45fe0a5 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt @@ -0,0 +1 @@ +data: {"choices": [{"delta": {"content": "the answer"},"finish_reason": null,"index": 0,"logprobs": null}],"created": 1711938086,"id": "chatcmpl-991aYb1iD8OSD54gcxZxv8uazlTZy","model": "gpt-4-0613","object": "chat.completion.chunk","system_fingerprint": null} diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt new file mode 100644 index 00000000000..fac4fed43ff --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt @@ -0,0 +1 @@ +data: {"choices": [{"finish_reason": null,"index": 0,"logprobs": null,"text": "the answer"}],"created": 1711938803,"id": "cmpl-991m7YSJWEnzrBqk41In8Xer9RIEB","model": "gpt-3.5-turbo-instruct","object": "text_completion"} \ No newline at end of file