From f7096f23cb85d383939d0fb4db10bf8a76f4694e Mon Sep 17 00:00:00 2001 From: Jack Tysoe <91137069+tysoekong@users.noreply.github.com> Date: Mon, 22 Jul 2024 08:43:12 +0100 Subject: [PATCH] feat(ai-proxy): google-gemini support (#9678) AG-27 --- .../kong/ai-proxy-google-gemini.yml | 5 + kong-3.8.0-0.rockspec | 2 + kong/llm/drivers/azure.lua | 34 +- kong/llm/drivers/gemini.lua | 440 ++++++++++++++++++ kong/llm/drivers/shared.lua | 92 +++- kong/llm/init.lua | 2 - kong/llm/schemas/init.lua | 38 +- kong/plugins/ai-proxy/handler.lua | 158 ++++--- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 57 +++ .../expected-requests/gemini/llm-v1-chat.json | 57 +++ .../gemini/llm-v1-chat.json | 14 + .../real-responses/gemini/llm-v1-chat.json | 34 ++ .../complete-json/expected-output.json | 8 + .../complete-json/input.bin | 2 + .../expected-output.json | 14 + .../partial-json-beginning/input.bin | 141 ++++++ .../partial-json-end/expected-output.json | 8 + .../partial-json-end/input.bin | 80 ++++ .../text-event-stream/expected-output.json | 11 + .../text-event-stream/input.bin | 7 + 20 files changed, 1111 insertions(+), 93 deletions(-) create mode 100644 changelog/unreleased/kong/ai-proxy-google-gemini.yml create mode 100644 kong/llm/drivers/gemini.lua create mode 100644 spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin diff --git a/changelog/unreleased/kong/ai-proxy-google-gemini.yml b/changelog/unreleased/kong/ai-proxy-google-gemini.yml new file mode 100644 index 000000000000..bc4fb06b21c4 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-google-gemini.yml @@ -0,0 +1,5 @@ +message: | + Kong AI Gateway (AI Proxy and associated plugin family) now supports + the Google Gemini "chat" (generateContent) interface. +type: feature +scope: Plugin diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index ca6c48f8da7a..16c7232334ef 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -909,6 +909,8 @@ build = { ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", + ["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua", + ["kong.plugins.ai-prompt-template.handler"] = "kong/plugins/ai-prompt-template/handler.lua", ["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua", ["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua", diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 9a06f9266486..957601c62be6 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -105,7 +105,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end -- returns err or nil -function _M.configure_request(conf) +function _M.configure_request(conf, identity_interface) local parsed_url if conf.model.options.upstream_url then @@ -136,26 +136,38 @@ function _M.configure_request(conf) local auth_param_name = conf.auth and conf.auth.param_name local auth_param_value = conf.auth and conf.auth.param_value local auth_param_location = conf.auth and conf.auth.param_location + local query_table = kong.request.get_query() - if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) - end + -- [[ EE + if identity_interface then -- managed identity mode, passed from ai-proxy handler.lua + local _, token, _, err = identity_interface.credentials:get() - local query_table = kong.request.get_query() + if err then + kong.log.err("failed to authenticate with Azure OpenAI, ", err) + return kong.response.exit(500, { error = { message = "failed to authenticate with Azure OpenAI" }}) + end + + kong.service.request.set_header("Authorization", "Bearer " .. token) + + else + if auth_header_name and auth_header_value then + kong.service.request.set_header(auth_header_name, auth_header_value) + end + + if auth_param_name and auth_param_value and auth_param_location == "query" then + query_table[auth_param_name] = auth_param_value + end + -- if auth_param_location is "form", it will have already been set in a pre-request hook + end + -- ]] -- technically min supported version query_table["api-version"] = kong.request.get_query_arg("api-version") or (conf.model.options and conf.model.options.azure_api_version) - - if auth_param_name and auth_param_value and auth_param_location == "query" then - query_table[auth_param_name] = auth_param_value - end kong.service.request.set_query(query_table) - -- if auth_param_location is "form", it will have already been set in a pre-request hook return true, nil end - return _M diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua new file mode 100644 index 000000000000..e36baf2a49f9 --- /dev/null +++ b/kong/llm/drivers/gemini.lua @@ -0,0 +1,440 @@ +-- This software is copyright Kong Inc. and its licensors. +-- Use of the software is subject to the agreement between your organization +-- and Kong Inc. If there is no such agreement, use is governed by and +-- subject to the terms of the Kong Master Software License Agreement found +-- at https://konghq.com/enterprisesoftwarelicense/. +-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] + +local _M = {} + +-- imports +local cjson = require("cjson.safe") +local fmt = string.format +local ai_shared = require("kong.llm.drivers.shared") +local socket_url = require("socket.url") +local string_gsub = string.gsub +local buffer = require("string.buffer") +local table_insert = table.insert +local string_lower = string.lower +-- + +-- globals +local DRIVER_NAME = "gemini" +-- + +local _OPENAI_ROLE_MAPPING = { + ["system"] = "system", + ["user"] = "user", + ["assistant"] = "model", +} + +local function to_gemini_generation_config(request_table) + return { + ["maxOutputTokens"] = request_table.max_tokens, + ["stopSequences"] = request_table.stop, + ["temperature"] = request_table.temperature, + ["topK"] = request_table.top_k, + ["topP"] = request_table.top_p, + } +end + +local function is_response_content(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].content + and content.candidates[1].content.parts + and #content.candidates[1].content.parts > 0 + and content.candidates[1].content.parts[1].text +end + +local function is_response_finished(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].finishReason +end + +local function handle_stream_event(event_t, model_info, route_type) + -- discard empty frames, it should either be a random new line, or comment + if (not event_t.data) or (#event_t.data < 1) then + return + end + + local event, err = cjson.decode(event_t.data) + if err then + ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err) + return nil, "failed to decode stream event frame from gemini", nil + end + + local new_event + local metadata = nil + + if is_response_content(event) then + new_event = { + choices = { + [1] = { + delta = { + content = event.candidates[1].content.parts[1].text or "", + role = "assistant", + }, + index = 0, + }, + }, + } + end + + if is_response_finished(event) then + metadata = metadata or {} + metadata.finished_reason = event.candidates[1].finishReason + new_event = "[DONE]" + end + + if event.usageMetadata then + metadata = metadata or {} + metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0 + metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0 + end + + if new_event then + if new_event ~= "[DONE]" then + new_event = cjson.encode(new_event) + end + + return new_event, nil, metadata + else + return nil, nil, metadata -- caller code will handle "unrecognised" event types + end +end + +local function to_gemini_chat_openai(request_table, model_info, route_type) + if request_table then -- try-catch type mechanism + local new_r = {} + + if request_table.messages and #request_table.messages > 0 then + local system_prompt + + for i, v in ipairs(request_table.messages) do + + -- for 'system', we just concat them all into one Gemini instruction + if v.role and v.role == "system" then + system_prompt = system_prompt or buffer.new() + system_prompt:put(v.content or "") + else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.contents = new_r.contents or {} + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + parts = { + { + text = v.content or "" + }, + }, + }) + end + end + + -- This was only added in Gemini 1.5 + if system_prompt and model_info.name:sub(1, 10) == "gemini-1.0" then + return nil, nil, "system prompts aren't supported on gemini-1.0 models" + + elseif system_prompt then + new_r.systemInstruction = { + parts = { + { + text = system_prompt:get(), + }, + }, + } + end + end + + new_r.generationConfig = to_gemini_generation_config(request_table) + + return new_r, "application/json", nil + end + + local new_r = {} + + if request_table.messages and #request_table.messages > 0 then + local system_prompt + + for i, v in ipairs(request_table.messages) do + + -- for 'system', we just concat them all into one Gemini instruction + if v.role and v.role == "system" then + system_prompt = system_prompt or buffer.new() + system_prompt:put(v.content or "") + else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.contents = new_r.contents or {} + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + parts = { + { + text = v.content or "" + }, + }, + }) + end + end + end + + new_r.generationConfig = to_gemini_generation_config(request_table) + + return new_r, "application/json", nil +end + +local function from_gemini_chat_openai(response, model_info, route_type) + local response, err = cjson.decode(response) + + if err then + local err_client = "failed to decode response from Gemini" + ngx.log(ngx.ERR, fmt("%s: %s", err_client, err)) + return nil, err_client + end + + -- messages/choices table is only 1 size, so don't need to static allocate + local messages = {} + messages.choices = {} + + if response.candidates + and #response.candidates > 0 + and is_response_content(response) then + + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.candidates[1].content.parts[1].text, + }, + finish_reason = string_lower(response.candidates[1].finishReason), + } + messages.object = "chat.completion" + messages.model = model_info.name + + else -- probably a server fault or other unexpected response + local err = "no generation candidates received from Gemini, or max_tokens too short" + ngx.log(ngx.ERR, err) + return nil, err + end + + return cjson.encode(messages) +end + +local transformers_to = { + ["llm/v1/chat"] = to_gemini_chat_openai, +} + +local transformers_from = { + ["llm/v1/chat"] = from_gemini_chat_openai, + ["stream/llm/v1/chat"] = handle_stream_event, +} + +function _M.from_format(response_string, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong") + + -- MUST return a string, to set as the response body + if not transformers_from[route_type] then + return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) + end + + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) + if not ok or err then + return nil, fmt("transformation failed from type %s://%s: %s", + model_info.provider, + route_type, + err or "unexpected_error" + ) + end + + return response_string, nil, metadata +end + +function _M.to_format(request_table, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type) + + if route_type == "preserve" then + -- do nothing + return request_table, nil, nil + end + + if not transformers_to[route_type] then + return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) + end + + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + + local ok, response_object, content_type, err = pcall( + transformers_to[route_type], + request_table, + model_info + ) + if err or (not ok) then + return nil, nil, fmt("error transforming to %s://%s: %s", model_info.provider, route_type, err) + end + + return response_object, content_type, nil +end + +function _M.subrequest(body, conf, http_opts, return_res_table) + -- use shared/standard subrequest routine + local body_string, err + + if type(body) == "table" then + body_string, err = cjson.encode(body) + if err then + return nil, nil, "failed to parse body to json: " .. err + end + elseif type(body) == "string" then + body_string = body + else + return nil, nil, "body must be table or string" + end + + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( + "%s%s", + ai_shared.upstream_url_format[DRIVER_NAME], + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + ) + + local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + + local headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + } + + if conf.auth and conf.auth.header_name then + headers[conf.auth.header_name] = conf.auth.header_value + end + + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) + if err then + return nil, nil, "request to ai service failed: " .. err + end + + if return_res_table then + return res, res.status, nil, httpc + else + -- At this point, the entire request / response is complete and the connection + -- will be closed or back on the connection pool. + local status = res.status + local body = res.body + + if status > 299 then + return body, res.status, "status code " .. status + end + + return body, res.status, nil + end +end + +function _M.header_filter_hooks(body) + -- nothing to parse in header_filter phase +end + +function _M.post_request(conf) + if ai_shared.clear_response_headers[DRIVER_NAME] then + for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do + kong.response.clear_header(v) + end + end +end + +function _M.pre_request(conf, body) + -- disable gzip for gemini because it breaks streaming + kong.service.request.set_header("Accept-Encoding", "identity") + + return true, nil +end + +-- returns err or nil +function _M.configure_request(conf, identity_interface) + local parsed_url + local operation = kong.ctx.shared.ai_proxy_streaming_mode and "streamGenerateContent" + or "generateContent" + local f_url = conf.model.options and conf.model.options.upstream_url + + if not f_url then -- upstream_url override is not set + -- check if this is "public" or "vertex" gemini deployment + if conf.model.options + and conf.model.options.gemini + and conf.model.options.gemini.api_endpoint + and conf.model.options.gemini.project_id + and conf.model.options.gemini.location_id + then + -- vertex mode + f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"], + conf.model.options.gemini.api_endpoint) .. + fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path, + conf.model.options.gemini.project_id, + conf.model.options.gemini.location_id, + conf.model.name, + operation) + else + -- public mode + f_url = ai_shared.upstream_url_format["gemini"] .. + fmt(ai_shared.operation_map["gemini"][conf.route_type].path, + conf.model.name, + operation) + end + end + + parsed_url = socket_url.parse(f_url) + + if conf.model.options and conf.model.options.upstream_path then + -- upstream path override is set (or templated from request params) + parsed_url.path = conf.model.options.upstream_path + end + + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") + + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + + local auth_header_name = conf.auth and conf.auth.header_name + local auth_header_value = conf.auth and conf.auth.header_value + local auth_param_name = conf.auth and conf.auth.param_name + local auth_param_value = conf.auth and conf.auth.param_value + local auth_param_location = conf.auth and conf.auth.param_location + + -- DBO restrictions makes sure that only one of these auth blocks runs in one plugin config + if auth_header_name and auth_header_value then + kong.service.request.set_header(auth_header_name, auth_header_value) + end + + if auth_param_name and auth_param_value and auth_param_location == "query" then + local query_table = kong.request.get_query() + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end + -- if auth_param_location is "form", it will have already been set in a global pre-request hook + + -- if we're passed a GCP SDK, for cloud identity / SSO, use it appropriately + if identity_interface then + if identity_interface:needsRefresh() then + -- HACK: A bug in lua-resty-gcp tries to re-load the environment + -- variable every time, which fails in nginx + -- Create a whole new interface instead. + -- Memory leaks are mega unlikely because this should only + -- happen about once an hour, and the old one will be + -- cleaned up anyway. + local service_account_json = identity_interface.service_account_json + local identity_interface_new = identity_interface:new(service_account_json) + identity_interface.token = identity_interface_new.token + + kong.log.notice("gcp identity token for ", kong.plugin.get_id(), " has been refreshed") + end + + kong.service.request.set_header("Authorization", "Bearer " .. identity_interface.token) + end + + return true +end + +return _M diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 713f199e1dc0..c82164258d2b 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -23,7 +23,7 @@ local split = require("kong.tools.string").split local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy local function str_ltrim(s) -- remove leading whitespace from string. - return (s:gsub("^%s*", "")) + return type(s) == "string" and s:gsub("^%s*", "") end -- @@ -62,6 +62,7 @@ _M.streaming_has_token_counts = { ["cohere"] = true, ["llama2"] = true, ["anthropic"] = true, + ["gemini"] = true, } _M.upstream_url_format = { @@ -69,6 +70,8 @@ _M.upstream_url_format = { anthropic = "https://api.anthropic.com:443", cohere = "https://api.cohere.com:443", azure = "https://%s.openai.azure.com:443/openai/deployments/%s", + gemini = "https://generativelanguage.googleapis.com", + gemini_vertex = "https://%s", } _M.operation_map = { @@ -112,6 +115,18 @@ _M.operation_map = { method = "POST", }, }, + gemini = { + ["llm/v1/chat"] = { + path = "/v1beta/models/%s:%s", + method = "POST", + }, + }, + gemini_vertex = { + ["llm/v1/chat"] = { + path = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + method = "POST", + }, + }, } _M.clear_response_headers = { @@ -127,6 +142,9 @@ _M.clear_response_headers = { mistral = { "Set-Cookie", }, + gemini = { + "Set-Cookie", + }, } --- @@ -206,21 +224,44 @@ end -- as if it were an SSE message. -- -- @param {string} frame input string to format into SSE events --- @param {string} delimiter delimeter (can be complex string) to split by +-- @param {boolean} raw_json sets application/json byte-parser mode -- @return {table} n number of split SSE messages, or empty table -function _M.frame_to_events(frame) +function _M.frame_to_events(frame, raw_json_mode) local events = {} + if (not frame) or (#frame < 1) or (type(frame)) ~= "string" then + return + end + + -- some new LLMs return the JSON object-by-object, + -- because that totally makes sense to parse?! + if raw_json_mode then + -- if this is the first frame, it will begin with array opener '[' + frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame + + -- it may start with ',' which is the start of the new frame + frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame + + -- finally, it may end with the array terminator ']' indicating the finished stream + frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame + + -- for multiple events that arrive in the same frame, split by top-level comma + for _, v in ipairs(split(frame, "\n,")) do + events[#events+1] = { data = v } + end + + -- check if it's raw json and just return the split up data frame -- Cohere / Other flat-JSON format parser -- just return the split up data frame - if (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then + elseif (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then for event in frame:gmatch("[^\r\n]+") do events[#events + 1] = { data = event, } end + + -- standard SSE parser else - -- standard SSE parser local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } @@ -233,7 +274,10 @@ function _M.frame_to_events(frame) -- test for truncated chunk on the last line (no trailing \r\n\r\n) if #dat > 0 and #event_lines == i then ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat + if kong then + kong.ctx.plugin.truncated_frame = dat + end + break -- stop parsing immediately, server has done something wrong end @@ -411,24 +455,26 @@ function _M.resolve_plugin_conf(kong_request, conf) -- handle all other options for k, v in pairs(conf.model.options or {}) do - local prop_m = string_match(v or "", '%$%((.-)%)') - if prop_m then - local splitted = split(prop_m, '.') - if #splitted ~= 2 then - return nil, "cannot parse expression for field '" .. v .. "'" - end - - -- find the request parameter, with the configured name - prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) - if err then - return nil, err - end - if not prop_m then - return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" - end + if type(v) == "string" then + local prop_m = string_match(v or "", '%$%((.-)%)') + if prop_m then + local splitted = split(prop_m, '.') + if #splitted ~= 2 then + return nil, "cannot parse expression for field '" .. v .. "'" + end + + -- find the request parameter, with the configured name + prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) + if err then + return nil, err + end + if not prop_m then + return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" + end - -- replace the value - conf_m.model.options[k] = prop_m + -- replace the value + conf_m.model.options[k] = prop_m + end end end diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 88820cb31136..3d7ff7eebe52 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -19,8 +19,6 @@ local _M = { config_schema = require "kong.llm.schemas", } - - do -- formats_compatible is a map of formats that are compatible with each other. local formats_compatible = { diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index f8a0f3ab6d08..29d2e7ab9619 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -10,6 +10,28 @@ local typedefs = require("kong.db.schema.typedefs") local fmt = string.format +local gemini_options_schema = { + type = "record", + required = false, + fields = { + { api_endpoint = { + type = "string", + description = "If running Gemini on Vertex, specify the regional API endpoint (hostname only).", + required = false }}, + { project_id = { + type = "string", + description = "If running Gemini on Vertex, specify the project ID.", + required = false }}, + { location_id = { + type = "string", + description = "If running Gemini on Vertex, specify the location ID.", + required = false }}, + }, + entity_checks = { + { mutually_required = { "api_endpoint", "project_id", "location_id" }, }, + }, +} + local auth_schema = { type = "record", @@ -81,11 +103,22 @@ local auth_schema = { } }, -- EE ]] + { gcp_use_service_account = { + type = "boolean", + description = "Use service account auth for GCP-based providers and models.", + required = false, + default = false }}, + { gcp_service_account_json = { + type = "string", + description = "Set this field to the full JSON of the GCP service account to authenticate, if required. " .. + "If null (and gcp_use_service_account is true), Kong will attempt to read from " .. + "environment variable `GCP_SERVICE_ACCOUNT`.", + required = false, + referenceable = true }}, } } - local model_options_schema = { description = "Key/value settings for the model", type = "record", @@ -157,6 +190,7 @@ local model_options_schema = { .. "used when e.g. using the 'preserve' route_type.", type = "string", required = false }}, + { gemini = gemini_options_schema }, } } @@ -170,7 +204,7 @@ local model_schema = { type = "string", description = "AI provider request format - Kong translates " .. "requests to and from the specified backend compatible formats.", required = true, - one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2" }}}, + one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini" }}}, { name = { type = "string", description = "Model name to execute.", diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 4e698233adca..f402d91b1b56 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -13,6 +13,14 @@ local kong_meta = require("kong.meta") local buffer = require "string.buffer" local strip = require("kong.tools.utils").strip +-- cloud auth/sdk providers +local GCP_SERVICE_ACCOUNT do + GCP_SERVICE_ACCOUNT = os.getenv("GCP_SERVICE_ACCOUNT") +end + +local GCP = require("resty.gcp.request.credentials.accesstoken") +-- + local EMPTY = {} @@ -23,43 +31,65 @@ local _M = { } +-- static messages +local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}' ---- Return a 400 response with a JSON body. This function is used to --- return errors to the client while also logging the error. -local function bad_request(msg) - kong.log.info(msg) - return kong.response.exit(400, { error = { message = msg } }) -end - - --- [[ EE local _KEYBASTION = setmetatable({}, { __mode = "k", __index = function(this_cache, plugin_config) + if plugin_config.model.provider == "gemini" and + plugin_config.auth and + plugin_config.auth.gcp_use_service_account then - ngx.log(ngx.DEBUG, "loading azure sdk for ", plugin_config.model.options.azure_deployment_id, " in ", plugin_config.model.options.azure_instance) + ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id()) - local azure_client = require("resty.azure"):new({ - client_id = plugin_config.auth.azure_client_id, - client_secret = plugin_config.auth.azure_client_secret, - tenant_id = plugin_config.auth.azure_tenant_id, - token_scope = "https://cognitiveservices.azure.com/.default", - token_version = "v2.0", - }) + local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT - local _, err = azure_client.authenticate() - if err then + local ok, gcp_auth = pcall(GCP.new, nil, service_account_json) + if ok and gcp_auth then + -- store our item for the next time we need it + gcp_auth.service_account_json = service_account_json + this_cache[plugin_config] = { interface = gcp_auth, error = nil } + return this_cache[plugin_config] + end + + return { interface = nil, error = "cloud-authentication with GCP failed" } + + -- [[ EE + elseif plugin_config.model.provider == "azure" + and plugin_config.auth.azure_use_managed_identity then + ngx.log(ngx.NOTICE, "loading azure sdk for plugin ", kong.plugin.get_id()) + + local azure_client = require("resty.azure"):new({ + client_id = plugin_config.auth.azure_client_id, + client_secret = plugin_config.auth.azure_client_secret, + tenant_id = plugin_config.auth.azure_tenant_id, + token_scope = "https://cognitiveservices.azure.com/.default", + token_version = "v2.0", + }) + + local _, err = azure_client.authenticate() + if not err then + -- store our item for the next time we need it + this_cache[plugin_config] = { interface = azure_client, error = nil } + return this_cache[plugin_config] + end + kong.log.err("failed to authenticate with Azure OpenAI: ", err) - return kong.response.exit(500, { error = { message = "failed to authenticate with Azure OpenAI" }}) + return { interface = nil, error = "managed identity auth with Azure OpenAI failed" } end + -- ]] - -- store our item for the next time we need it - this_cache[plugin_config] = azure_client - return azure_client end, }) --- ]] + + +local function bad_request(msg) + kong.log.info(msg) + return kong.response.exit(400, { error = { message = msg } }) +end + -- get the token text from an event frame local function get_token_text(event_t) @@ -74,7 +104,6 @@ local function get_token_text(event_t) end - local function handle_streaming_frame(conf) -- make a re-usable framebuffer local framebuffer = buffer.new() @@ -98,10 +127,43 @@ local function handle_streaming_frame(conf) -- because we have already 200 OK'd the client by now if (not finished) and (is_gzip) then - chunk = kong_utils.inflate_gzip(chunk) + chunk = kong_utils.inflate_gzip(ngx.arg[1]) end - local events = ai_shared.frame_to_events(chunk) + local is_raw_json = conf.model.provider == "gemini" + local events = ai_shared.frame_to_events(chunk, is_raw_json ) + + if not events then + -- usually a not-supported-transformer or empty frames. + -- header_filter has already run, so all we can do is log it, + -- and then send the client a readable error in a single chunk + local response = ERROR__NOT_SET + + if is_gzip then + response = kong_utils.deflate_gzip(response) + end + + ngx.arg[1] = response + ngx.arg[2] = true + + return + end + + if not events then + -- usually a not-supported-transformer or empty frames. + -- header_filter has already run, so all we can do is log it, + -- and then send the client a readable error in a single chunk + local response = ERROR__NOT_SET + + if is_gzip then + response = kong_utils.deflate_gzip(response) + end + + ngx.arg[1] = response + ngx.arg[2] = true + + return + end for _, event in ipairs(events) do local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type) @@ -355,7 +417,7 @@ function _M:access(conf) if not request_table then if not string.find(content_type, "multipart/form-data", nil, true) then - return bad_request("content-type header does not match request body") + return bad_request("content-type header does not match request body, or bad JSON formatting") end multipart = true -- this may be a large file upload, so we have to proxy it directly @@ -400,30 +462,6 @@ function _M:access(conf) end end - -- check the incoming format is the same as the configured LLM format - local compatible, err = llm.is_compatible(request_table, route_type) - if not compatible then - kong_ctx_shared.skip_response_transformer = true - return bad_request(err) - end - - -- [[ EE - if conf.model.provider == "azure" - and conf.auth.azure_use_managed_identity then - local identity_interface = _KEYBASTION[conf] - if identity_interface then - local _, token, _, err = identity_interface.credentials:get() - - if err then - kong.log.err("failed to authenticate with Azure Content Services, ", err) - return kong.response.exit(500, { error = { message = "failed to authenticate with Azure Content Services" }}) - end - - kong.service.request.set_header("Authorization", "Bearer " .. token) - end - end - -- ]] - -- check if the user has asked for a stream, and/or if -- we are forcing all requests to be of streaming type if request_table and request_table.stream or @@ -436,8 +474,9 @@ function _M:access(conf) return bad_request("response streaming is not enabled for this LLM") end - -- store token cost estimate, on first pass - if not kong_ctx_plugin.ai_stream_prompt_tokens then + -- store token cost estimate, on first pass, if the + -- provider doesn't reply with a prompt token count + if (not kong.ctx.plugin.ai_stream_prompt_tokens) and (not ai_shared.streaming_has_token_counts[conf_m.model.provider]) then local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) if err then kong.log.err("unable to estimate request token cost: ", err) @@ -483,8 +522,17 @@ function _M:access(conf) kong.service.request.set_body(parsed_request_body, content_type) end + -- get the provider's cached identity interface - nil may come back, which is fine + local identity_interface = _KEYBASTION[conf] + if identity_interface and identity_interface.error then + kong.ctx.shared.skip_response_transformer = true + kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error) + return kong.response.exit(500, "LLM request failed before proxying") + end + -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf_m) + local ok, err = ai_driver.configure_request(conf_m, + identity_interface and identity_interface.interface) if not ok then kong_ctx_shared.skip_response_transformer = true kong.log.err("failed to configure request for AI service: ", err) diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 30b19aa2fa0f..aa82e11486c1 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -230,6 +230,20 @@ local FORMATS = { }, }, }, + gemini = { + ["llm/v1/chat"] = { + config = { + name = "gemini-pro", + provider = "gemini", + options = { + max_tokens = 8192, + temperature = 0.8, + top_k = 1, + top_p = 0.6, + }, + }, + }, + }, } local STREAMS = { @@ -653,5 +667,48 @@ describe(PLUGIN_NAME .. ": (unit)", function() }, formatted) end) + describe("streaming transformer tests", function() + + it("transforms truncated-json type (beginning of stream)", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin")) + local events = ai_shared.frame_to_events(input, true) + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events, true) + end) + + it("transforms truncated-json type (end of stream)", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin")) + local events = ai_shared.frame_to_events(input, true) + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events, true) + end) + + it("transforms complete-json type", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) + local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events) + end) + + it("transforms text/event-stream type", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin")) + local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events) + end) + + end) end) diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..f236df678a4d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json @@ -0,0 +1,57 @@ +{ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "What is 1 + 2?" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Multiply that by 2" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Why can't you divide by zero?" + } + ] + } + ], + "generationConfig": { + "temperature": 0.8, + "topK": 1, + "topP": 0.6, + "maxOutputTokens": 8192 + }, + "systemInstruction": { + "parts": [ + { + "text": "You are a mathematician." + } + ] + } +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..90a1656d2a37 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json @@ -0,0 +1,14 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n", + "role": "assistant" + } + } + ], + "model": "gemini-pro", + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..80781b6eb72a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json @@ -0,0 +1,34 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ] + } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json new file mode 100644 index 000000000000..b08549afbf4e --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json @@ -0,0 +1,8 @@ +[ + { + "data": "{\"is_finished\":false,\"event_type\":\"stream-start\",\"generation_id\":\"10f31c2f-1a4c-48cf-b500-dc8141a25ae5\"}" + }, + { + "data": "{\"is_finished\":false,\"event_type\":\"text-generation\",\"text\":\"2\"}" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin new file mode 100644 index 000000000000..af13220a423a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin @@ -0,0 +1,2 @@ +{"is_finished":false,"event_type":"stream-start","generation_id":"10f31c2f-1a4c-48cf-b500-dc8141a25ae5"} +{"is_finished":false,"event_type":"text-generation","text":"2"} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json new file mode 100644 index 000000000000..5f3b0afa51d4 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json @@ -0,0 +1,14 @@ +[ + { + "data": "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"The\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 1,\n \"totalTokenCount\": 7\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" theory of relativity is actually two theories by Albert Einstein: **special relativity** and\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 23\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" **general relativity**. Here's a simplified breakdown:\\n\\n**Special Relativity (\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 33,\n \"totalTokenCount\": 39\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"1905):**\\n\\n* **Focus:** The relationship between space and time.\\n* **Key ideas:**\\n * **Speed of light\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 65,\n \"totalTokenCount\": 71\n }\n}\n" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin new file mode 100644 index 000000000000..8cef2a01fa8d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin @@ -0,0 +1,141 @@ +[{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 1, + "totalTokenCount": 7 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " theory of relativity is actually two theories by Albert Einstein: **special relativity** and" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 17, + "totalTokenCount": 23 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " **general relativity**. Here's a simplified breakdown:\n\n**Special Relativity (" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 33, + "totalTokenCount": 39 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "1905):**\n\n* **Focus:** The relationship between space and time.\n* **Key ideas:**\n * **Speed of light" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 65, + "totalTokenCount": 71 + } +} diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json new file mode 100644 index 000000000000..ba6a64384d95 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json @@ -0,0 +1,8 @@ +[ + { + "data": "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" is constant:** No matter how fast you are moving, light always travels at the same speed (approximately 299,792,458\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 97,\n \"totalTokenCount\": 103\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" not a limit.\\n\\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \\n\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 547,\n \"totalTokenCount\": 553\n }\n}\n" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin new file mode 100644 index 000000000000..d6489e74d19d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin @@ -0,0 +1,80 @@ +,{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " is constant:** No matter how fast you are moving, light always travels at the same speed (approximately 299,792,458" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 97, + "totalTokenCount": 103 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " not a limit.\n\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 547, + "totalTokenCount": 553 + } +} +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json new file mode 100644 index 000000000000..f515516c7ec8 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json @@ -0,0 +1,11 @@ +[ + { + "data": "{ \"choices\": [ { \"delta\": { \"content\": \"\", \"role\": \"assistant\" }, \"finish_reason\": null, \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + }, + { + "data": "{ \"choices\": [ { \"delta\": { \"content\": \"2\" }, \"finish_reason\": null, \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + }, + { + "data": "{ \"choices\": [ { \"delta\": {}, \"finish_reason\": \"stop\", \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin new file mode 100644 index 000000000000..efe2ad50c657 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin @@ -0,0 +1,7 @@ +data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: { "choices": [ { "delta": { "content": "2" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: [DONE] \ No newline at end of file