From 3a79e8348919ba3f414d69b086d2902d2f3085c4 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Sat, 27 Apr 2024 23:38:10 +0100 Subject: [PATCH 01/20] feat(ai-proxy): google-gemini support --- .../kong/ai-proxy-google-gemini.yml | 5 + kong/llm/drivers/gemini.lua | 312 ++++++++++++++++++ kong/llm/drivers/shared.lua | 10 + kong/llm/init.lua | 1 - kong/llm/schemas/init.lua | 2 +- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 14 + .../expected-requests/gemini/llm-v1-chat.json | 57 ++++ .../gemini/llm-v1-chat.json | 14 + .../real-responses/gemini/llm-v1-chat.json | 34 ++ 9 files changed, 447 insertions(+), 2 deletions(-) create mode 100644 changelog/unreleased/kong/ai-proxy-google-gemini.yml create mode 100644 kong/llm/drivers/gemini.lua create mode 100644 spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json diff --git a/changelog/unreleased/kong/ai-proxy-google-gemini.yml b/changelog/unreleased/kong/ai-proxy-google-gemini.yml new file mode 100644 index 000000000000..bc4fb06b21c4 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-google-gemini.yml @@ -0,0 +1,5 @@ +message: | + Kong AI Gateway (AI Proxy and associated plugin family) now supports + the Google Gemini "chat" (generateContent) interface. +type: feature +scope: Plugin diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua new file mode 100644 index 000000000000..5360690b6c63 --- /dev/null +++ b/kong/llm/drivers/gemini.lua @@ -0,0 +1,312 @@ +local _M = {} + +-- imports +local cjson = require("cjson.safe") +local fmt = string.format +local ai_shared = require("kong.llm.drivers.shared") +local socket_url = require("socket.url") +local string_gsub = string.gsub +local buffer = require("string.buffer") +local table_insert = table.insert +local string_lower = string.lower +-- + +-- globals +local DRIVER_NAME = "gemini" +-- + +local _OPENAI_ROLE_MAPPING = { + ["system"] = "system", + ["user"] = "user", + ["assistant"] = "model", +} + +local function to_bard_generation_config(request_table) + return { + ["maxOutputTokens"] = request_table.max_tokens, + ["stopSequences"] = request_table.stop, + ["temperature"] = request_table.temperature, + ["topK"] = request_table.top_k, + ["topP"] = request_table.top_p, + } +end + +local function to_bard_chat_openai(request_table, model_info, route_type) + if request_table then -- try-catch type mechanism + local new_r = {} + + if request_table.messages and #request_table.messages > 0 then + local system_prompt + + for i, v in ipairs(request_table.messages) do + + -- for 'system', we just concat them all into one Gemini instruction + if v.role and v.role == "system" then + system_prompt = system_prompt or buffer.new() + system_prompt:put(v.content or "") + else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.contents = new_r.contents or {} + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + parts = { + { + text = v.content or "" + }, + }, + }) + end + end + + ---- TODO for some reason this is broken? + ---- I think it's something to do with which "regional" endpoint of Gemini you hit... + -- if system_prompt then + -- new_r.systemInstruction = { + -- parts = { + -- { + -- text = system_prompt:get(), + -- }, + -- }, + -- } + -- end + ---- + + end + + new_r.generationConfig = to_bard_generation_config(request_table) + + kong.log.debug(cjson.encode(new_r)) + + return new_r, "application/json", nil + end + + local err = "empty request table received for transformation" + ngx.log(ngx.ERR, err) + return nil, nil, err +end + +local function from_bard_chat_openai(response, model_info, route_type) + local response, err = cjson.decode(response) + + if err then + local err_client = "failed to decode response from Gemini" + ngx.log(ngx.ERR, fmt("%s: %s", err_client, err)) + return nil, err_client + end + + -- messages/choices table is only 1 size, so don't need to static allocate + local messages = {} + messages.choices = {} + + if response.candidates + and #response.candidates > 0 + and response.candidates[1].content + and response.candidates[1].content.parts + and #response.candidates[1].content.parts > 0 + and response.candidates[1].content.parts[1].text then + + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.candidates[1].content.parts[1].text, + }, + finish_reason = string_lower(response.candidates[1].finishReason), + } + messages.object = "chat.completion" + messages.model = model_info.name + + else -- probably a server fault or other unexpected response + local err = "no generation candidates received from Gemini, or max_tokens too short" + ngx.log(ngx.ERR, err) + return nil, err + end + + return cjson.encode(messages) +end + +local function to_bard_chat_bard(request_table, model_info, route_type) + return nil, nil, "bard to bard not yet implemented" +end + +local function from_bard_chat_bard(request_table, model_info, route_type) + return nil, nil, "bard to bard not yet implemented" +end + +local transformers_to = { + ["llm/v1/chat"] = to_bard_chat_openai, + ["gemini/v1/chat"] = to_gemini_chat_bard, +} + +local transformers_from = { + ["llm/v1/chat"] = from_bard_chat_openai, + ["gemini/v1/chat"] = from_gemini_chat_bard, +} + +function _M.from_format(response_string, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong") + + -- MUST return a string, to set as the response body + if not transformers_from[route_type] then + return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) + end + + local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info, route_type) + if not ok or err then + return nil, fmt("transformation failed from type %s://%s: %s", + model_info.provider, + route_type, + err or "unexpected_error" + ) + end + + return response_string, nil +end + +function _M.to_format(request_table, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type) + + if route_type == "preserve" then + -- do nothing + return request_table, nil, nil + end + + if not transformers_to[route_type] then + return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) + end + + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + + local ok, response_object, content_type, err = pcall( + transformers_to[route_type], + request_table, + model_info + ) + if err or (not ok) then + return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type) + end + + return response_object, content_type, nil +end + +function _M.subrequest(body, conf, http_opts, return_res_table) + -- use shared/standard subrequest routine + local body_string, err + + if type(body) == "table" then + body_string, err = cjson.encode(body) + if err then + return nil, nil, "failed to parse body to json: " .. err + end + elseif type(body) == "string" then + body_string = body + else + return nil, nil, "body must be table or string" + end + + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( + "%s%s", + ai_shared.upstream_url_format[DRIVER_NAME], + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + ) + + local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + + local headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + } + + if conf.auth and conf.auth.header_name then + headers[conf.auth.header_name] = conf.auth.header_value + end + + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) + if err then + return nil, nil, "request to ai service failed: " .. err + end + + if return_res_table then + return res, res.status, nil, httpc + else + -- At this point, the entire request / response is complete and the connection + -- will be closed or back on the connection pool. + local status = res.status + local body = res.body + + if status > 299 then + return body, res.status, "status code " .. status + end + + return body, res.status, nil + end +end + +function _M.header_filter_hooks(body) + -- nothing to parse in header_filter phase +end + +function _M.post_request(conf) + if ai_shared.clear_response_headers[DRIVER_NAME] then + for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do + kong.response.clear_header(v) + end + end +end + +function _M.pre_request(conf, body) + kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli + + return true, nil +end + +-- returns err or nil +function _M.configure_request(conf) + local parsed_url + + if (conf.model.options and conf.model.options.upstream_url) then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + local path = conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and fmt(ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, conf.model.name) + or "/" + if not path then + return nil, fmt("operation %s is not supported for openai provider", conf.route_type) + end + + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = path + end + + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") + + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + + local auth_header_name = conf.auth and conf.auth.header_name + local auth_header_value = conf.auth and conf.auth.header_value + local auth_param_name = conf.auth and conf.auth.param_name + local auth_param_value = conf.auth and conf.auth.param_value + local auth_param_location = conf.auth and conf.auth.param_location + + if auth_header_name and auth_header_value then + kong.service.request.set_header(auth_header_name, auth_header_value) + end + + if auth_param_name and auth_param_value and auth_param_location == "query" then + local query_table = kong.request.get_query() + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end + + -- if auth_param_location is "form", it will have already been set in a global pre-request hook + return true, nil +end + +return _M diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 9d62998c34cd..9babf77019ae 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -62,6 +62,7 @@ _M.upstream_url_format = { anthropic = "https://api.anthropic.com:443", cohere = "https://api.cohere.com:443", azure = "https://%s.openai.azure.com:443/openai/deployments/%s", + gemini = "https://generativelanguage.googleapis.com", } _M.operation_map = { @@ -105,6 +106,12 @@ _M.operation_map = { method = "POST", }, }, + gemini = { + ["llm/v1/chat"] = { + path = "/v1/models/%s:generateContent", -- /v1/models/gemini-pro:generateContent, + method = "POST", + }, + }, } _M.clear_response_headers = { @@ -120,6 +127,9 @@ _M.clear_response_headers = { mistral = { "Set-Cookie", }, + gemini = { + "Set-Cookie", + }, } --- diff --git a/kong/llm/init.lua b/kong/llm/init.lua index aaf3af08a790..64533e14e700 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -11,7 +11,6 @@ local _M = { } - do -- formats_compatible is a map of formats that are compatible with each other. local formats_compatible = { diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index 15ce1a2a1ef0..ec1586b5eeb2 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -123,7 +123,7 @@ local model_schema = { type = "string", description = "AI provider request format - Kong translates " .. "requests to and from the specified backend compatible formats.", required = true, - one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2" }}}, + one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini" }}}, { name = { type = "string", description = "Model name to execute.", diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index c1dfadfb4aca..bd47caac286a 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -223,6 +223,20 @@ local FORMATS = { }, }, }, + gemini = { + ["llm/v1/chat"] = { + config = { + name = "gemini-pro", + provider = "gemini", + options = { + max_tokens = 8192, + temperature = 0.8, + top_k = 1, + top_p = 0.6, + }, + }, + }, + }, } local STREAMS = { diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..f236df678a4d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json @@ -0,0 +1,57 @@ +{ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "What is 1 + 2?" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Multiply that by 2" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Why can't you divide by zero?" + } + ] + } + ], + "generationConfig": { + "temperature": 0.8, + "topK": 1, + "topP": 0.6, + "maxOutputTokens": 8192 + }, + "systemInstruction": { + "parts": [ + { + "text": "You are a mathematician." + } + ] + } +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..90a1656d2a37 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json @@ -0,0 +1,14 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n", + "role": "assistant" + } + } + ], + "model": "gemini-pro", + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..80781b6eb72a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json @@ -0,0 +1,34 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ] + } \ No newline at end of file From f384a15a24f220267fda00f4411e1e43201c35a6 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 01:14:56 +0100 Subject: [PATCH 02/20] fix gemini system prompt --- kong/llm/drivers/gemini.lua | 49 +++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 5360690b6c63..cf674bc1fc2c 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -9,6 +9,7 @@ local string_gsub = string.gsub local buffer = require("string.buffer") local table_insert = table.insert local string_lower = string.lower +local string_sub = string.sub -- -- globals @@ -21,7 +22,7 @@ local _OPENAI_ROLE_MAPPING = { ["assistant"] = "model", } -local function to_bard_generation_config(request_table) +local function to_gemini_generation_config(request_table) return { ["maxOutputTokens"] = request_table.max_tokens, ["stopSequences"] = request_table.stop, @@ -31,7 +32,7 @@ local function to_bard_generation_config(request_table) } end -local function to_bard_chat_openai(request_table, model_info, route_type) +local function to_gemini_chat_openai(request_table, model_info, route_type) if request_table then -- try-catch type mechanism local new_r = {} @@ -58,22 +59,22 @@ local function to_bard_chat_openai(request_table, model_info, route_type) end end - ---- TODO for some reason this is broken? - ---- I think it's something to do with which "regional" endpoint of Gemini you hit... - -- if system_prompt then - -- new_r.systemInstruction = { - -- parts = { - -- { - -- text = system_prompt:get(), - -- }, - -- }, - -- } - -- end - ---- + -- This was only added in Gemini 1.5 + if system_prompt and model_info.name:sub(1, 10) == "gemini-1.0" then + return nil, nil, "system prompts aren't supported on gemini-1.0 models" + elseif system_prompt then + new_r.systemInstruction = { + parts = { + { + text = system_prompt:get(), + }, + }, + } + end end - new_r.generationConfig = to_bard_generation_config(request_table) + new_r.generationConfig = to_gemini_generation_config(request_table) kong.log.debug(cjson.encode(new_r)) @@ -85,7 +86,7 @@ local function to_bard_chat_openai(request_table, model_info, route_type) return nil, nil, err end -local function from_bard_chat_openai(response, model_info, route_type) +local function from_gemini_chat_openai(response, model_info, route_type) local response, err = cjson.decode(response) if err then @@ -125,22 +126,22 @@ local function from_bard_chat_openai(response, model_info, route_type) return cjson.encode(messages) end -local function to_bard_chat_bard(request_table, model_info, route_type) - return nil, nil, "bard to bard not yet implemented" +local function to_gemini_chat_gemini(request_table, model_info, route_type) + return nil, nil, "gemini to gemini not yet implemented" end -local function from_bard_chat_bard(request_table, model_info, route_type) - return nil, nil, "bard to bard not yet implemented" +local function from_gemini_chat_gemini(request_table, model_info, route_type) + return nil, nil, "gemini to gemini not yet implemented" end local transformers_to = { - ["llm/v1/chat"] = to_bard_chat_openai, - ["gemini/v1/chat"] = to_gemini_chat_bard, + ["llm/v1/chat"] = to_gemini_chat_openai, + ["gemini/v1/chat"] = to_gemini_chat_gemini, } local transformers_from = { - ["llm/v1/chat"] = from_bard_chat_openai, - ["gemini/v1/chat"] = from_gemini_chat_bard, + ["llm/v1/chat"] = from_gemini_chat_openai, + ["gemini/v1/chat"] = from_gemini_chat_gemini, } function _M.from_format(response_string, model_info, route_type) From a0b02968c5bd01fd85bd18f7b857481d8cf7c647 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 20 May 2024 14:59:13 +0100 Subject: [PATCH 03/20] stash --- Dockerfile | 8 ++++++ kong/llm/drivers/gemini.lua | 52 ++++++++++++++++++++++++++++++++++--- kong/llm/drivers/shared.lua | 2 +- 3 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000000..5eaa3f4877aa --- /dev/null +++ b/Dockerfile @@ -0,0 +1,8 @@ +FROM kong:3.6.1 + +USER root + +COPY kong/plugins/ai-proxy/ /usr/local/share/lua/5.1/kong/plugins/ai-proxy/ +COPY kong/llm/ /usr/local/share/lua/5.1/kong/llm/ + +USER kong diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index cf674bc1fc2c..c09b0561c01c 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -81,9 +81,53 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) return new_r, "application/json", nil end - local err = "empty request table received for transformation" - ngx.log(ngx.ERR, err) - return nil, nil, err + local new_r = {} + + if request_table.messages and #request_table.messages > 0 then + local system_prompt + + for i, v in ipairs(request_table.messages) do + + -- for 'system', we just concat them all into one Gemini instruction + if v.role and v.role == "system" then + system_prompt = system_prompt or buffer.new() + system_prompt:put(v.content or "") + else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.contents = new_r.contents or {} + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + parts = { + { + text = v.content or "" + }, + }, + }) + end + end + + -- only works for gemini 1.5+ + -- if system_prompt then + -- if string_sub(model_info.name, 1, 10) == "gemini-1.0" then + -- return nil, nil, "system prompts only work with gemini models 1.5 or later" + -- end + + -- new_r.systemInstruction = { + -- parts = { + -- { + -- text = system_prompt:get(), + -- }, + -- }, + -- } + -- end + -- + end + + kong.log.debug(cjson.encode(new_r)) + + new_r.generationConfig = to_gemini_generation_config(request_table) + + return new_r, "application/json", nil end local function from_gemini_chat_openai(response, model_info, route_type) @@ -184,7 +228,7 @@ function _M.to_format(request_table, model_info, route_type) model_info ) if err or (not ok) then - return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type) + return nil, nil, fmt("error transforming to %s://%s: %s", model_info.provider, route_type, err) end return response_object, content_type, nil diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 9babf77019ae..cc1ca7915a06 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -108,7 +108,7 @@ _M.operation_map = { }, gemini = { ["llm/v1/chat"] = { - path = "/v1/models/%s:generateContent", -- /v1/models/gemini-pro:generateContent, + path = "/v1/models/%s:generateContent", method = "POST", }, }, From f5d75bb1b2292862458636130598161b4b9100a7 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 03:32:35 +0100 Subject: [PATCH 04/20] working public and vertex gemini chat text --- Dockerfile | 1 + kong-3.8.0-0.rockspec | 2 ++ kong/llm/drivers/gemini.lua | 53 +++++++++++++++++++++++++++---------- kong/llm/drivers/shared.lua | 45 ++++++++++++++++++------------- kong/llm/init.lua | 1 - kong/llm/schemas/init.lua | 33 ++++++++++++++++++++++- 6 files changed, 101 insertions(+), 34 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5eaa3f4877aa..df43e39b2ca1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,7 @@ FROM kong:3.6.1 USER root +RUN luarocks install lua-resty-gcp COPY kong/plugins/ai-proxy/ /usr/local/share/lua/5.1/kong/plugins/ai-proxy/ COPY kong/llm/ /usr/local/share/lua/5.1/kong/llm/ diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index f0c4dbe8ce12..a0f1c786e878 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -43,6 +43,7 @@ dependencies = { "lpeg == 1.1.0", "lua-resty-ljsonschema == 1.1.6-2", "lua-resty-snappy == 1.0-1", + "lua-resty-gcp == 0.0.13-1", } build = { type = "builtin", @@ -606,6 +607,7 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", + ["kong.llm.auth.gcp] = "kong/llm/auth/gcp.lua", ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index c09b0561c01c..cc57f0b39b47 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -310,23 +310,43 @@ end -- returns err or nil function _M.configure_request(conf) local parsed_url - - if (conf.model.options and conf.model.options.upstream_url) then - parsed_url = socket_url.parse(conf.model.options.upstream_url) - else - local path = conf.model.options - and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type] - and fmt(ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, conf.model.name) - or "/" - if not path then - return nil, fmt("operation %s is not supported for openai provider", conf.route_type) + local operation = kong.ctx.shared.ai_proxy_streaming_mode and "streamGenerateContent" + or "generateContent" + local f_url = conf.model.options and conf.model.options.upstream_url + + if not f_url then -- upstream_url override is not set + -- check if this is "public" or "vertex" gemini deployment + if conf.model.options.gemini + and conf.model.options.gemini.api_endpoint + and conf.model.options.gemini.project_id + and conf.model.options.gemini.location_id + then + -- vertex mode + f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"], + conf.model.options.gemini.api_endpoint) .. + fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path, + conf.model.options.gemini.project_id, + conf.model.options.gemini.location_id, + conf.model.name, + operation) + else + -- public mode + f_url = ai_shared.upstream_url_format["gemini"] .. + fmt(ai_shared.operation_map["gemini"][conf.route_type].path, + conf.model.name, + operation) end + end + + parsed_url = socket_url.parse(f_url) - parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = path + kong.log.inspect(parsed_url) + + if conf.model.options and conf.model.options.upstream_path then + -- upstream path override is set (or templated from request params) + parsed_url.path = conf.model.options.upstream_path end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") @@ -350,6 +370,11 @@ function _M.configure_request(conf) kong.service.request.set_query(query_table) end + ---- DEBUG REMOVE THIS + local auth = require("resty.gcp.request.credentials.accesstoken"):new(conf.auth.gcp_service_account_json) + kong.service.request.set_header("Authorization", "Bearer " .. auth.token) + ---- + -- if auth_param_location is "form", it will have already been set in a global pre-request hook return true, nil end diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index cc1ca7915a06..f2560ffea6b3 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -63,6 +63,7 @@ _M.upstream_url_format = { cohere = "https://api.cohere.com:443", azure = "https://%s.openai.azure.com:443/openai/deployments/%s", gemini = "https://generativelanguage.googleapis.com", + gemini_vertex = "https://%s", } _M.operation_map = { @@ -108,7 +109,13 @@ _M.operation_map = { }, gemini = { ["llm/v1/chat"] = { - path = "/v1/models/%s:generateContent", + path = "/v1beta/models/%s:%s", + method = "POST", + }, + }, + gemini_vertex = { + ["llm/v1/chat"] = { + path = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", method = "POST", }, }, @@ -414,24 +421,26 @@ function _M.resolve_plugin_conf(kong_request, conf) -- handle all other options for k, v in pairs(conf.model.options or {}) do - local prop_m = string_match(v or "", '%$%((.-)%)') - if prop_m then - local splitted = split(prop_m, '.') - if #splitted ~= 2 then - return nil, "cannot parse expression for field '" .. v .. "'" - end - - -- find the request parameter, with the configured name - prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) - if err then - return nil, err - end - if not prop_m then - return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" - end + if type(v) == "string" then + local prop_m = string_match(v or "", '%$%((.-)%)') + if prop_m then + local splitted = split(prop_m, '.') + if #splitted ~= 2 then + return nil, "cannot parse expression for field '" .. v .. "'" + end + + -- find the request parameter, with the configured name + prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) + if err then + return nil, err + end + if not prop_m then + return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" + end - -- replace the value - conf_m.model.options[k] = prop_m + -- replace the value + conf_m.model.options[k] = prop_m + end end end diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 64533e14e700..85802e54b9c7 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -10,7 +10,6 @@ local _M = { config_schema = require "kong.llm.schemas", } - do -- formats_compatible is a map of formats that are compatible with each other. local formats_compatible = { diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index ec1586b5eeb2..7aaaf19b38ef 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -2,6 +2,25 @@ local typedefs = require("kong.db.schema.typedefs") local fmt = string.format +local gemini_options_schema = { + type = "record", + required = false, + fields = { + { api_endpoint = { + type = "string", + description = "If running Gemini on Vertex, specify the regional API endpoint (hostname only).", + required = false }}, + { project_id = { + type = "string", + description = "If running Gemini on Vertex, specify the project ID.", + required = false }}, + { location_id = { + type = "string", + description = "If running Gemini on Vertex, specify the location ID.", + required = false }}, + }, +} + local auth_schema = { type = "record", @@ -34,11 +53,22 @@ local auth_schema = { description = "Specify whether the 'param_name' and 'param_value' options go in a query string, or the POST form/JSON body.", required = false, one_of = { "query", "body" } }}, + { gcp_use_service_account = { + type = "boolean", + description = "Use service account auth for GCP-based providers and models.", + required = false, + default = false }}, + { gcp_service_account_json = { + type = "string", + description = "Set this field to the full JSON of the GCP service account to authenticate, if required. " .. + "If null (and gcp_use_service_account is true), Kong will attempt to read from " .. + "environment variable `GCP_SERVICE_ACCOUNT`.", + required = false, + referenceable = true }}, } } - local model_options_schema = { description = "Key/value settings for the model", type = "record", @@ -110,6 +140,7 @@ local model_options_schema = { .. "used when e.g. using the 'preserve' route_type.", type = "string", required = false }}, + { gemini = gemini_options_schema }, } } From 49a6784e730d1fe2212f2660b8934f009d1c554e Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 15:58:06 +0100 Subject: [PATCH 05/20] stash --- kong/llm/drivers/gemini.lua | 11 ++++++----- kong/plugins/ai-proxy/handler.lua | 7 ------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index cc57f0b39b47..13eb16024071 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -316,7 +316,8 @@ function _M.configure_request(conf) if not f_url then -- upstream_url override is not set -- check if this is "public" or "vertex" gemini deployment - if conf.model.options.gemini + if conf.model.options + and conf.model.options.gemini and conf.model.options.gemini.api_endpoint and conf.model.options.gemini.project_id and conf.model.options.gemini.location_id @@ -370,10 +371,10 @@ function _M.configure_request(conf) kong.service.request.set_query(query_table) end - ---- DEBUG REMOVE THIS - local auth = require("resty.gcp.request.credentials.accesstoken"):new(conf.auth.gcp_service_account_json) - kong.service.request.set_header("Authorization", "Bearer " .. auth.token) - ---- + -- ---- DEBUG REMOVE THIS + -- local auth = require("resty.gcp.request.credentials.accesstoken"):new(conf.auth.gcp_service_account_json) + -- kong.service.request.set_header("Authorization", "Bearer " .. auth.token) + -- ---- -- if auth_param_location is "form", it will have already been set in a global pre-request hook return true, nil diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 35e13fbe8d9e..17eea0667246 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -365,13 +365,6 @@ function _M:access(conf) end end - -- check the incoming format is the same as the configured LLM format - local compatible, err = llm.is_compatible(request_table, route_type) - if not compatible then - kong_ctx_shared.skip_response_transformer = true - return bad_request(err) - end - -- check if the user has asked for a stream, and/or if -- we are forcing all requests to be of streaming type if request_table and request_table.stream or From 32c9528e801a42e97e534c14ac2d3f4a721908bf Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 21:43:00 +0100 Subject: [PATCH 06/20] working static format for vertex; patched broken streaming termination --- kong-3.8.0-0.rockspec | 3 +- kong/llm/drivers/gemini.lua | 29 +++++++++++---- kong/llm/drivers/shared.lua | 7 +++- kong/plugins/ai-proxy/handler.lua | 61 ++++++++++++++++++++++++++++--- 4 files changed, 85 insertions(+), 15 deletions(-) diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index a0f1c786e878..bcf8b38cb059 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -35,6 +35,7 @@ dependencies = { "lua-messagepack == 0.5.4", "lua-resty-aws == 1.5.0", "lua-resty-openssl == 1.4.0", + "lua-resty-gcp == 0.0.13", "lua-resty-counter == 0.2.1", "lua-resty-ipmatcher == 0.6.1", "lua-resty-acme == 0.14.0", @@ -43,7 +44,6 @@ dependencies = { "lpeg == 1.1.0", "lua-resty-ljsonschema == 1.1.6-2", "lua-resty-snappy == 1.0-1", - "lua-resty-gcp == 0.0.13-1", } build = { type = "builtin", @@ -607,7 +607,6 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", - ["kong.llm.auth.gcp] = "kong/llm/auth/gcp.lua", ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 13eb16024071..f21ad8494110 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -308,7 +308,7 @@ function _M.pre_request(conf, body) end -- returns err or nil -function _M.configure_request(conf) +function _M.configure_request(conf, identity_interface) local parsed_url local operation = kong.ctx.shared.ai_proxy_streaming_mode and "streamGenerateContent" or "generateContent" @@ -361,6 +361,7 @@ function _M.configure_request(conf) local auth_param_value = conf.auth and conf.auth.param_value local auth_param_location = conf.auth and conf.auth.param_location + -- DBO restrictions makes sure that only one of these auth blocks runs in one plugin config if auth_header_name and auth_header_value then kong.service.request.set_header(auth_header_name, auth_header_value) end @@ -370,14 +371,28 @@ function _M.configure_request(conf) query_table[auth_param_name] = auth_param_value kong.service.request.set_query(query_table) end + -- if auth_param_location is "form", it will have already been set in a global pre-request hook - -- ---- DEBUG REMOVE THIS - -- local auth = require("resty.gcp.request.credentials.accesstoken"):new(conf.auth.gcp_service_account_json) - -- kong.service.request.set_header("Authorization", "Bearer " .. auth.token) - -- ---- + -- if we're passed a GCP SDK, for cloud identity / SSO, use it appropriately + if identity_interface then + if identity_interface:needsRefresh() then + -- HACK: A bug in lua-resty-gcp tries to re-load the environment + -- variable every time, which fails in nginx + -- Create a whole new interface instead. + -- Memory leaks are mega unlikely because this should only + -- happen about once an hour, and the old one will be + -- cleaned up anyway. + local service_account_json = identity_interface.service_account_json + local identity_interface_new = identity_interface:new(service_account_json) + identity_interface.token = identity_interface_new.token + + kong.log.notice("gcp identity token for ", kong.plugin.get_id(), " has been refreshed") + end - -- if auth_param_location is "form", it will have already been set in a global pre-request hook - return true, nil + kong.service.request.set_header("Authorization", "Bearer " .. identity_interface.token) + end + + return true end return _M diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index f2560ffea6b3..d23edd9ab403 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -16,7 +16,7 @@ local split = require("kong.tools.string").split local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy local function str_ltrim(s) -- remove leading whitespace from string. - return (s:gsub("^%s*", "")) + return type(s) == "string" and s:gsub("^%s*", "") end -- @@ -221,6 +221,11 @@ end function _M.frame_to_events(frame) local events = {} + if (not frame) or #frame < 1 then + return + end + + -- check if it's raw json and just return the split up data frame -- Cohere / Other flat-JSON format parser -- just return the split up data frame if (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 17eea0667246..dbdebefa5c56 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -6,6 +6,14 @@ local kong_meta = require("kong.meta") local buffer = require "string.buffer" local strip = require("kong.tools.utils").strip +-- cloud auth/sdk providers +local GCP_SERVICE_ACCOUNT do + GCP_SERVICE_ACCOUNT = os.getenv("GCP_SERVICE_ACCOUNT") +end + +local GCP = require("resty.gcp.request.credentials.accesstoken") +-- + local EMPTY = {} @@ -16,16 +24,37 @@ local _M = { } +local _KEYBASTION = setmetatable({}, { + __mode = "k", + __index = function(this_cache, plugin_config) + if plugin_config.model.provider == "gemini" and + plugin_config.auth and + plugin_config.auth.gcp_use_service_account then + + ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id()) + + local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT + + local ok, gcp_auth = pcall(GCP.new, nil, service_account_json) + if ok and gcp_auth then + -- store our item for the next time we need it + gcp_auth.service_account_json = service_account_json + this_cache[plugin_config] = { interface = gcp_auth, error = nil } + return this_cache[plugin_config] + end + + return { interface = nil, error = "cloud-authentication with GCP failed" } + end + end, +}) + ---- Return a 400 response with a JSON body. This function is used to --- return errors to the client while also logging the error. local function bad_request(msg) kong.log.info(msg) return kong.response.exit(400, { error = { message = msg } }) end - -- get the token text from an event frame local function get_token_text(event_t) -- get: event_t.choices[1] @@ -68,6 +97,19 @@ local function handle_streaming_frame(conf) local events = ai_shared.frame_to_events(chunk) + if not events then + local response = 'data: {"error": true, "message": "empty transformer response"}' + + if is_gzip then + response = kong_utils.deflate_gzip(response) + end + + ngx.arg[1] = response + ngx.arg[2] = true + + return + end + for _, event in ipairs(events) do local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type) @@ -320,7 +362,7 @@ function _M:access(conf) if not request_table then if not string.find(content_type, "multipart/form-data", nil, true) then - return bad_request("content-type header does not match request body") + return bad_request("content-type header does not match request body, or bad JSON formatting") end multipart = true -- this may be a large file upload, so we have to proxy it directly @@ -424,8 +466,17 @@ function _M:access(conf) kong.service.request.set_body(parsed_request_body, content_type) end + -- get the provider's cached identity interface - nil may come back, which is fine + local identity_interface = _KEYBASTION[conf] + if identity_interface.error then + kong.ctx.shared.skip_response_transformer = true + kong.log.err("error authenticating with cloud-provider, ", identity_interface.error) + + return internal_server_error("LLM request failed before proxying") + end + -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf_m) + local ok, err = ai_driver.configure_request(conf_m, identity_interface.interface) if not ok then kong_ctx_shared.skip_response_transformer = true kong.log.err("failed to configure request for AI service: ", err) From e4c3137fb72bc14991dcc1300d980fceeac5479e Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 22 May 2024 00:40:54 +0100 Subject: [PATCH 07/20] finished gemini (text-only) support --- kong/llm/drivers/gemini.lua | 75 ++++++++++++++++++++++++++++--- kong/llm/drivers/shared.lua | 63 ++++++++++++++++---------- kong/plugins/ai-proxy/handler.lua | 26 +++++++---- 3 files changed, 126 insertions(+), 38 deletions(-) diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index f21ad8494110..eba9ed02afb8 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -32,6 +32,69 @@ local function to_gemini_generation_config(request_table) } end +local function handle_stream_event(event_t, model_info, route_type) + local metadata + + + -- discard empty frames, it should either be a random new line, or comment + if (not event_t.data) or (#event_t.data < 1) then + return + end + + local event, err = cjson.decode(event_t.data) + if err then + ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err) + return nil, "failed to decode stream event frame from gemini", nil + end + + local new_event + local metadata + + if event.candidates and + #event.candidates > 0 then + + if event.candidates[1].content and + event.candidates[1].content.parts and + #event.candidates[1].content.parts > 0 and + event.candidates[1].content.parts[1].text then + + new_event = { + choices = { + [1] = { + delta = { + content = event.candidates[1].content.parts[1].text or "", + role = "assistant", + }, + index = 0, + }, + }, + } + end + + if event.candidates[1].finishReason then + metadata = metadata or {} + metadata.finished_reason = event.candidates[1].finishReason + new_event = "[DONE]" + end + end + + if event.usageMetadata then + metadata = metadata or {} + metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0 + metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0 + end + + if new_event then + if new_event ~= "[DONE]" then + new_event = cjson.encode(new_event) + end + + return new_event, nil, metadata + else + return nil, nil, metadata -- caller code will handle "unrecognised" event types + end +end + local function to_gemini_chat_openai(request_table, model_info, route_type) if request_table then -- try-catch type mechanism local new_r = {} @@ -180,12 +243,11 @@ end local transformers_to = { ["llm/v1/chat"] = to_gemini_chat_openai, - ["gemini/v1/chat"] = to_gemini_chat_gemini, } local transformers_from = { ["llm/v1/chat"] = from_gemini_chat_openai, - ["gemini/v1/chat"] = from_gemini_chat_gemini, + ["stream/llm/v1/chat"] = handle_stream_event, } function _M.from_format(response_string, model_info, route_type) @@ -196,7 +258,7 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info, route_type) + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, @@ -205,7 +267,7 @@ function _M.from_format(response_string, model_info, route_type) ) end - return response_string, nil + return response_string, nil, metadata end function _M.to_format(request_table, model_info, route_type) @@ -302,7 +364,8 @@ function _M.post_request(conf) end function _M.pre_request(conf, body) - kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli + -- disable gzip for gemini because it breaks streaming + kong.service.request.set_header("Accept-Encoding", "identity") return true, nil end @@ -341,8 +404,6 @@ function _M.configure_request(conf, identity_interface) parsed_url = socket_url.parse(f_url) - kong.log.inspect(parsed_url) - if conf.model.options and conf.model.options.upstream_path then -- upstream path override is set (or templated from request params) parsed_url.path = conf.model.options.upstream_path diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index d23edd9ab403..a74f1352b9ca 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -55,6 +55,7 @@ _M.streaming_has_token_counts = { ["cohere"] = true, ["llama2"] = true, ["anthropic"] = true, + ["gemini"] = true, } _M.upstream_url_format = { @@ -216,12 +217,12 @@ end -- as if it were an SSE message. -- -- @param {string} frame input string to format into SSE events --- @param {string} delimiter delimeter (can be complex string) to split by +-- @param {boolean} raw_json sets application/json byte-parser mode -- @return {table} n number of split SSE messages, or empty table -function _M.frame_to_events(frame) +function _M.frame_to_events(frame, raw_json_mode) local events = {} - if (not frame) or #frame < 1 then + if (not frame) or (#frame < 1) or (type(frame)) ~= "string" then return end @@ -234,36 +235,52 @@ function _M.frame_to_events(frame) data = event, } end + + -- some new LLMs return the JSON object-by-object, + -- because that totally makes sense to parse?! + elseif raw_json_mode then + -- if this is the first frame, it will begin with array opener '[' + frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame + + -- it may start with ',' which is the start of the new frame + frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame + + -- finally, it may end with the array terminator ']' indicating the finished stream + frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame + + -- for multiple events that arrive in the same frame, split by top-level comma + for _, v in ipairs(split(frame, "\n,")) do + events[#events+1] = { data = v } + end + else -- standard SSE parser local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } - for i, dat in ipairs(event_lines) do - if #dat < 1 then - events[#events + 1] = struct - struct = { event = nil, id = nil, data = nil } - end + -- test for truncated chunk on the last line (no trailing \r\n\r\n) + if #dat > 0 and #event_lines == i then + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + kong.ctx.plugin.truncated_frame = dat + break -- stop parsing immediately, server has done something wrong + end - -- test for truncated chunk on the last line (no trailing \r\n\r\n) - if #dat > 0 and #event_lines == i then - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat - break -- stop parsing immediately, server has done something wrong - end + -- test for abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) + kong.ctx.plugin.truncated_frame = nil + end - -- test for abnormal start-of-frame (truncation tail) - if kong and kong.ctx.plugin.truncated_frame then - -- this is the tail of a previous incomplete chunk - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") - dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) - kong.ctx.plugin.truncated_frame = nil - end + local s1, _ = str_find(dat, ":") -- find where the cut point is - local s1, _ = str_find(dat, ":") -- find where the cut point is + if s1 and s1 ~= 1 then + local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world + local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world + local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index dbdebefa5c56..cb9a16663ddb 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -24,6 +24,11 @@ local _M = { } +-- static messages +local ERROR_MSG = { error = { message = "" } } +local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}' + + local _KEYBASTION = setmetatable({}, { __mode = "k", __index = function(this_cache, plugin_config) @@ -92,14 +97,17 @@ local function handle_streaming_frame(conf) -- because we have already 200 OK'd the client by now if (not finished) and (is_gzip) then - chunk = kong_utils.inflate_gzip(chunk) + chunk = kong_utils.inflate_gzip(ngx.arg[1]) end - local events = ai_shared.frame_to_events(chunk) + local events = ai_shared.frame_to_events(chunk, conf.model.provider == "gemini") if not events then - local response = 'data: {"error": true, "message": "empty transformer response"}' - + -- usually a not-supported-transformer or empty frames. + -- header_filter has already run, so all we can do is log it, + -- and then send the client a readable error in a single chunk + local response = ERROR__NOT_SET + if is_gzip then response = kong_utils.deflate_gzip(response) end @@ -419,8 +427,9 @@ function _M:access(conf) return bad_request("response streaming is not enabled for this LLM") end - -- store token cost estimate, on first pass - if not kong_ctx_plugin.ai_stream_prompt_tokens then + -- store token cost estimate, on first pass, if the + -- provider doesn't reply with a prompt token count + if (not kong.ctx.plugin.ai_stream_prompt_tokens) and (not ai_shared.streaming_has_token_counts[conf_m.model.provider]) then local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) if err then kong.log.err("unable to estimate request token cost: ", err) @@ -468,7 +477,7 @@ function _M:access(conf) -- get the provider's cached identity interface - nil may come back, which is fine local identity_interface = _KEYBASTION[conf] - if identity_interface.error then + if identity_interface and identity_interface.error then kong.ctx.shared.skip_response_transformer = true kong.log.err("error authenticating with cloud-provider, ", identity_interface.error) @@ -476,7 +485,8 @@ function _M:access(conf) end -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf_m, identity_interface.interface) + local ok, err = ai_driver.configure_request(conf_m, + identity_interface and identity_interface.interface) if not ok then kong_ctx_shared.skip_response_transformer = true kong.log.err("failed to configure request for AI service: ", err) From d148a9256d1d2e2eeda6d28b73e6527cc203fefe Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 22 May 2024 00:47:17 +0100 Subject: [PATCH 08/20] entity catch on gemini --- kong/llm/schemas/init.lua | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index 7aaaf19b38ef..9dc68f16db8a 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -19,6 +19,9 @@ local gemini_options_schema = { description = "If running Gemini on Vertex, specify the location ID.", required = false }}, }, + entity_checks = { + { mutually_required = { "api_endpoint", "project_id", "location_id" }, }, + }, } From 3ef1f50750d7f4eeb36d803c833e721f05d213f3 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 28 Jun 2024 15:51:25 +0100 Subject: [PATCH 09/20] repair gemini split logic --- kong/llm/drivers/shared.lua | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index a74f1352b9ca..70eeb64085fe 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -226,19 +226,9 @@ function _M.frame_to_events(frame, raw_json_mode) return end - -- check if it's raw json and just return the split up data frame - -- Cohere / Other flat-JSON format parser - -- just return the split up data frame - if (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then - for event in frame:gmatch("[^\r\n]+") do - events[#events + 1] = { - data = event, - } - end - -- some new LLMs return the JSON object-by-object, -- because that totally makes sense to parse?! - elseif raw_json_mode then + if raw_json_mode then -- if this is the first frame, it will begin with array opener '[' frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame @@ -253,8 +243,18 @@ function _M.frame_to_events(frame, raw_json_mode) events[#events+1] = { data = v } end + -- check if it's raw json and just return the split up data frame + -- Cohere / Other flat-JSON format parser + -- just return the split up data frame + elseif (not kong or not kong.ctx.plugin.truncated_frame) and string.sub(str_ltrim(frame), 1, 1) == "{" then + for event in frame:gmatch("[^\r\n]+") do + events[#events + 1] = { + data = event, + } + end + + -- standard SSE parser else - -- standard SSE parser local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } From c207f2bad90d490d0afe27c837cd9823b5d7fe4a Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 28 Jun 2024 16:38:05 +0100 Subject: [PATCH 10/20] fix(ai-proxy): gemini streaming transformer bug --- kong/llm/drivers/gemini.lua | 4 +--- kong/llm/drivers/shared.lua | 38 +++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index eba9ed02afb8..8344e990e33a 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -136,11 +136,9 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) } end end - + new_r.generationConfig = to_gemini_generation_config(request_table) - kong.log.debug(cjson.encode(new_r)) - return new_r, "application/json", nil end diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 70eeb64085fe..25d9d5773149 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -258,29 +258,31 @@ function _M.frame_to_events(frame, raw_json_mode) local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } - -- test for truncated chunk on the last line (no trailing \r\n\r\n) - if #dat > 0 and #event_lines == i then - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat - break -- stop parsing immediately, server has done something wrong - end + for i, dat in ipairs(event_lines) do + if #dat < 1 then + events[#events + 1] = struct + struct = { event = nil, id = nil, data = nil } + end - -- test for abnormal start-of-frame (truncation tail) - if kong and kong.ctx.plugin.truncated_frame then - -- this is the tail of a previous incomplete chunk - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") - dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) - kong.ctx.plugin.truncated_frame = nil - end + -- test for truncated chunk on the last line (no trailing \r\n\r\n) + if #dat > 0 and #event_lines == i then + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + kong.ctx.plugin.truncated_frame = dat + break -- stop parsing immediately, server has done something wrong + end - local s1, _ = str_find(dat, ":") -- find where the cut point is + -- test for abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) + kong.ctx.plugin.truncated_frame = nil + end - if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world - local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world + local s1, _ = str_find(dat, ":") -- find where the cut point is if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world + local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set From dd303d6ac199ee94030fd30ef23feb46a0947502 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 03:32:35 +0100 Subject: [PATCH 11/20] working public and vertex gemini chat text --- kong-3.8.0-0.rockspec | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index bcf8b38cb059..7f86eef73561 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -44,6 +44,7 @@ dependencies = { "lpeg == 1.1.0", "lua-resty-ljsonschema == 1.1.6-2", "lua-resty-snappy == 1.0-1", + "lua-resty-gcp == 0.0.13-1", } build = { type = "builtin", @@ -607,6 +608,7 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", + ["kong.llm.auth.gcp] = "kong/llm/auth/gcp.lua", ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", From 2f998b6704d37eadbf48c82df9a0a762db51f61e Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 21 May 2024 21:43:00 +0100 Subject: [PATCH 12/20] working static format for vertex; patched broken streaming termination --- kong-3.8.0-0.rockspec | 2 -- kong/plugins/ai-proxy/handler.lua | 13 +++++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index 7f86eef73561..bcf8b38cb059 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -44,7 +44,6 @@ dependencies = { "lpeg == 1.1.0", "lua-resty-ljsonschema == 1.1.6-2", "lua-resty-snappy == 1.0-1", - "lua-resty-gcp == 0.0.13-1", } build = { type = "builtin", @@ -608,7 +607,6 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", - ["kong.llm.auth.gcp] = "kong/llm/auth/gcp.lua", ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index cb9a16663ddb..68a437e72f44 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -118,6 +118,19 @@ local function handle_streaming_frame(conf) return end + if not events then + local response = 'data: {"error": true, "message": "empty transformer response"}' + + if is_gzip then + response = kong_utils.deflate_gzip(response) + end + + ngx.arg[1] = response + ngx.arg[2] = true + + return + end + for _, event in ipairs(events) do local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type) From 569054a8fe244e5a820f12ce6087020699373f23 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 22 May 2024 00:40:54 +0100 Subject: [PATCH 13/20] finished gemini (text-only) support --- kong/llm/drivers/shared.lua | 38 +++++++++++++++---------------- kong/plugins/ai-proxy/handler.lua | 7 ++++-- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 25d9d5773149..70eeb64085fe 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -258,31 +258,29 @@ function _M.frame_to_events(frame, raw_json_mode) local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } - for i, dat in ipairs(event_lines) do - if #dat < 1 then - events[#events + 1] = struct - struct = { event = nil, id = nil, data = nil } - end + -- test for truncated chunk on the last line (no trailing \r\n\r\n) + if #dat > 0 and #event_lines == i then + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + kong.ctx.plugin.truncated_frame = dat + break -- stop parsing immediately, server has done something wrong + end - -- test for truncated chunk on the last line (no trailing \r\n\r\n) - if #dat > 0 and #event_lines == i then - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat - break -- stop parsing immediately, server has done something wrong - end + -- test for abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) + kong.ctx.plugin.truncated_frame = nil + end - -- test for abnormal start-of-frame (truncation tail) - if kong and kong.ctx.plugin.truncated_frame then - -- this is the tail of a previous incomplete chunk - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") - dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) - kong.ctx.plugin.truncated_frame = nil - end + local s1, _ = str_find(dat, ":") -- find where the cut point is - local s1, _ = str_find(dat, ":") -- find where the cut point is + if s1 and s1 ~= 1 then + local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world + local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world + local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 68a437e72f44..b9204ed403e8 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -119,8 +119,11 @@ local function handle_streaming_frame(conf) end if not events then - local response = 'data: {"error": true, "message": "empty transformer response"}' - + -- usually a not-supported-transformer or empty frames. + -- header_filter has already run, so all we can do is log it, + -- and then send the client a readable error in a single chunk + local response = ERROR__NOT_SET + if is_gzip then response = kong_utils.deflate_gzip(response) end From 5043361d20d25553792656e1578890a21f0546e6 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 28 Jun 2024 16:38:05 +0100 Subject: [PATCH 14/20] fix(ai-proxy): gemini streaming transformer bug --- kong/llm/drivers/shared.lua | 38 +++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 70eeb64085fe..25d9d5773149 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -258,29 +258,31 @@ function _M.frame_to_events(frame, raw_json_mode) local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } - -- test for truncated chunk on the last line (no trailing \r\n\r\n) - if #dat > 0 and #event_lines == i then - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat - break -- stop parsing immediately, server has done something wrong - end + for i, dat in ipairs(event_lines) do + if #dat < 1 then + events[#events + 1] = struct + struct = { event = nil, id = nil, data = nil } + end - -- test for abnormal start-of-frame (truncation tail) - if kong and kong.ctx.plugin.truncated_frame then - -- this is the tail of a previous incomplete chunk - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") - dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) - kong.ctx.plugin.truncated_frame = nil - end + -- test for truncated chunk on the last line (no trailing \r\n\r\n) + if #dat > 0 and #event_lines == i then + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + kong.ctx.plugin.truncated_frame = dat + break -- stop parsing immediately, server has done something wrong + end - local s1, _ = str_find(dat, ":") -- find where the cut point is + -- test for abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) + kong.ctx.plugin.truncated_frame = nil + end - if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world - local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world + local s1, _ = str_find(dat, ":") -- find where the cut point is if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world + local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set From 14721f850396ab13c960fbbf4ac4cc43906ba6f2 Mon Sep 17 00:00:00 2001 From: Jack Tysoe <91137069+tysoekong@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:00:43 +0100 Subject: [PATCH 15/20] Delete Dockerfile --- Dockerfile | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index df43e39b2ca1..000000000000 --- a/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM kong:3.6.1 - -USER root - -RUN luarocks install lua-resty-gcp -COPY kong/plugins/ai-proxy/ /usr/local/share/lua/5.1/kong/plugins/ai-proxy/ -COPY kong/llm/ /usr/local/share/lua/5.1/kong/llm/ - -USER kong From 23d95f236895e49f8aacce66f893a5e4e43c8e47 Mon Sep 17 00:00:00 2001 From: Jack Tysoe <91137069+tysoekong@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:01:13 +0100 Subject: [PATCH 16/20] Update kong/plugins/ai-proxy/handler.lua Co-authored-by: Wangchong Zhou --- kong/plugins/ai-proxy/handler.lua | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index b9204ed403e8..fbe32f84fcb7 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -100,7 +100,8 @@ local function handle_streaming_frame(conf) chunk = kong_utils.inflate_gzip(ngx.arg[1]) end - local events = ai_shared.frame_to_events(chunk, conf.model.provider == "gemini") + local is_raw_json = conf.model.provider == "gemini" + local events = ai_shared.frame_to_events(chunk, is_raw_json ) if not events then -- usually a not-supported-transformer or empty frames. From 9dea66fdf66b1fcd1c0132849e8cd3dd5015fd8c Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 4 Jul 2024 13:24:26 +0100 Subject: [PATCH 17/20] fix(ai-proxy): gemini duplication --- kong/llm/drivers/gemini.lua | 78 ++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 8344e990e33a..b62f7f7d98dc 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -32,6 +32,23 @@ local function to_gemini_generation_config(request_table) } end +local function is_response_content(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].content + and content.candidates[1].content.parts + and #content.candidates[1].content.parts > 0 + and content.candidates[1].content.parts[1].text +end + +local function is_response_finished(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].finishReason +end + local function handle_stream_event(event_t, model_info, route_type) local metadata @@ -50,32 +67,24 @@ local function handle_stream_event(event_t, model_info, route_type) local new_event local metadata - if event.candidates and - #event.candidates > 0 then - - if event.candidates[1].content and - event.candidates[1].content.parts and - #event.candidates[1].content.parts > 0 and - event.candidates[1].content.parts[1].text then - - new_event = { - choices = { - [1] = { - delta = { - content = event.candidates[1].content.parts[1].text or "", - role = "assistant", - }, - index = 0, + if is_response_content(event) then + new_event = { + choices = { + [1] = { + delta = { + content = event.candidates[1].content.parts[1].text or "", + role = "assistant", }, + index = 0, }, - } - end + }, + } + end - if event.candidates[1].finishReason then - metadata = metadata or {} - metadata.finished_reason = event.candidates[1].finishReason - new_event = "[DONE]" - end + if is_response_finished(event) then + metadata = metadata or {} + metadata.finished_reason = event.candidates[1].finishReason + new_event = "[DONE]" end if event.usageMetadata then @@ -166,26 +175,8 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) }) end end - - -- only works for gemini 1.5+ - -- if system_prompt then - -- if string_sub(model_info.name, 1, 10) == "gemini-1.0" then - -- return nil, nil, "system prompts only work with gemini models 1.5 or later" - -- end - - -- new_r.systemInstruction = { - -- parts = { - -- { - -- text = system_prompt:get(), - -- }, - -- }, - -- } - -- end - -- end - kong.log.debug(cjson.encode(new_r)) - new_r.generationConfig = to_gemini_generation_config(request_table) return new_r, "application/json", nil @@ -206,10 +197,7 @@ local function from_gemini_chat_openai(response, model_info, route_type) if response.candidates and #response.candidates > 0 - and response.candidates[1].content - and response.candidates[1].content.parts - and #response.candidates[1].content.parts > 0 - and response.candidates[1].content.parts[1].text then + and is_response_content(response) then messages.choices[1] = { index = 0, From 65b0d8f6034d599dcde7343057b7583d5877df31 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 4 Jul 2024 20:21:19 +0100 Subject: [PATCH 18/20] gemini tests --- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 35 +++++++++++++++++++ .../complete-json/expected-output.json | 8 +++++ .../complete-json/input.bin | 2 ++ 3 files changed, 45 insertions(+) create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index bd47caac286a..5173739af8f1 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -661,4 +661,39 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) + local function dump(o) + if type(o) == 'table' then + local s = '{ ' + for k,v in pairs(o) do + if type(k) ~= 'number' then k = '"'..k..'"' end + s = s .. '['..k..'] = ' .. dump(v) .. ',' + end + return s .. '} ' + else + return tostring(o) + end + end + + describe("streaming transformer tests", function() + + it("transforms truncated-json type", function() + + end) + + it("transforms complete-json type", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) + local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events) + end) + + it("transforms text/event-stream type", function() + + end) + + end) + end) diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json new file mode 100644 index 000000000000..b08549afbf4e --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json @@ -0,0 +1,8 @@ +[ + { + "data": "{\"is_finished\":false,\"event_type\":\"stream-start\",\"generation_id\":\"10f31c2f-1a4c-48cf-b500-dc8141a25ae5\"}" + }, + { + "data": "{\"is_finished\":false,\"event_type\":\"text-generation\",\"text\":\"2\"}" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin new file mode 100644 index 000000000000..af13220a423a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin @@ -0,0 +1,2 @@ +{"is_finished":false,"event_type":"stream-start","generation_id":"10f31c2f-1a4c-48cf-b500-dc8141a25ae5"} +{"is_finished":false,"event_type":"text-generation","text":"2"} \ No newline at end of file From 6c2cc3e11995bf8903f3d3af3747b89703486fd9 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 5 Jul 2024 00:43:42 +0100 Subject: [PATCH 19/20] feat(ai-proxy): stream format tests --- kong/llm/drivers/shared.lua | 5 +- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 28 +++- .../expected-output.json | 14 ++ .../partial-json-beginning/input.bin | 141 ++++++++++++++++++ .../partial-json-end/expected-output.json | 8 + .../partial-json-end/input.bin | 80 ++++++++++ .../text-event-stream/expected-output.json | 11 ++ .../text-event-stream/input.bin | 7 + 8 files changed, 290 insertions(+), 4 deletions(-) create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json create mode 100644 spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 25d9d5773149..0e1d0d18a962 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -267,7 +267,10 @@ function _M.frame_to_events(frame, raw_json_mode) -- test for truncated chunk on the last line (no trailing \r\n\r\n) if #dat > 0 and #event_lines == i then ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") - kong.ctx.plugin.truncated_frame = dat + if kong then + kong.ctx.plugin.truncated_frame = dat + end + break -- stop parsing immediately, server has done something wrong end diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 5173739af8f1..22fb1e668e34 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -676,8 +676,24 @@ describe(PLUGIN_NAME .. ": (unit)", function() describe("streaming transformer tests", function() - it("transforms truncated-json type", function() - + it("transforms truncated-json type (beginning of stream)", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin")) + local events = ai_shared.frame_to_events(input, true) + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events, true) + end) + + it("transforms truncated-json type (end of stream)", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin")) + local events = ai_shared.frame_to_events(input, true) + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events, true) end) it("transforms complete-json type", function() @@ -691,7 +707,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) it("transforms text/event-stream type", function() - + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin")) + local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.same(events, expected_events) end) end) diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json new file mode 100644 index 000000000000..5f3b0afa51d4 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json @@ -0,0 +1,14 @@ +[ + { + "data": "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"The\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 1,\n \"totalTokenCount\": 7\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" theory of relativity is actually two theories by Albert Einstein: **special relativity** and\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 23\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" **general relativity**. Here's a simplified breakdown:\\n\\n**Special Relativity (\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 33,\n \"totalTokenCount\": 39\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \"1905):**\\n\\n* **Focus:** The relationship between space and time.\\n* **Key ideas:**\\n * **Speed of light\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 65,\n \"totalTokenCount\": 71\n }\n}\n" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin new file mode 100644 index 000000000000..8cef2a01fa8d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin @@ -0,0 +1,141 @@ +[{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 1, + "totalTokenCount": 7 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " theory of relativity is actually two theories by Albert Einstein: **special relativity** and" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 17, + "totalTokenCount": 23 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " **general relativity**. Here's a simplified breakdown:\n\n**Special Relativity (" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 33, + "totalTokenCount": 39 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "1905):**\n\n* **Focus:** The relationship between space and time.\n* **Key ideas:**\n * **Speed of light" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 65, + "totalTokenCount": 71 + } +} diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json new file mode 100644 index 000000000000..ba6a64384d95 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json @@ -0,0 +1,8 @@ +[ + { + "data": "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" is constant:** No matter how fast you are moving, light always travels at the same speed (approximately 299,792,458\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 97,\n \"totalTokenCount\": 103\n }\n}" + }, + { + "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" not a limit.\\n\\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \\n\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 547,\n \"totalTokenCount\": 553\n }\n}\n" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin new file mode 100644 index 000000000000..d6489e74d19d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin @@ -0,0 +1,80 @@ +,{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " is constant:** No matter how fast you are moving, light always travels at the same speed (approximately 299,792,458" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 97, + "totalTokenCount": 103 + } +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " not a limit.\n\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 6, + "candidatesTokenCount": 547, + "totalTokenCount": 553 + } +} +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json new file mode 100644 index 000000000000..f515516c7ec8 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json @@ -0,0 +1,11 @@ +[ + { + "data": "{ \"choices\": [ { \"delta\": { \"content\": \"\", \"role\": \"assistant\" }, \"finish_reason\": null, \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + }, + { + "data": "{ \"choices\": [ { \"delta\": { \"content\": \"2\" }, \"finish_reason\": null, \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + }, + { + "data": "{ \"choices\": [ { \"delta\": {}, \"finish_reason\": \"stop\", \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin new file mode 100644 index 000000000000..efe2ad50c657 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin @@ -0,0 +1,7 @@ +data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: { "choices": [ { "delta": { "content": "2" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} + +data: [DONE] \ No newline at end of file From ae686a42488415c593e950cf0bf48422e0619578 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Fri, 5 Jul 2024 01:09:32 +0100 Subject: [PATCH 20/20] fix(ai-proxy): lint --- kong-3.8.0-0.rockspec | 2 ++ kong/llm/drivers/gemini.lua | 16 ++-------------- kong/plugins/ai-proxy/handler.lua | 4 +--- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 14 -------------- 4 files changed, 5 insertions(+), 31 deletions(-) diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index bcf8b38cb059..ce680566797a 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -608,6 +608,8 @@ build = { ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", + ["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua", + ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index b62f7f7d98dc..59296ee9160b 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -9,7 +9,6 @@ local string_gsub = string.gsub local buffer = require("string.buffer") local table_insert = table.insert local string_lower = string.lower -local string_sub = string.sub -- -- globals @@ -49,10 +48,7 @@ local function is_response_finished(content) and content.candidates[1].finishReason end -local function handle_stream_event(event_t, model_info, route_type) - local metadata - - +local function handle_stream_event(event_t, model_info, route_type) -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then return @@ -65,7 +61,7 @@ local function handle_stream_event(event_t, model_info, route_type) end local new_event - local metadata + local metadata = nil if is_response_content(event) then new_event = { @@ -219,14 +215,6 @@ local function from_gemini_chat_openai(response, model_info, route_type) return cjson.encode(messages) end -local function to_gemini_chat_gemini(request_table, model_info, route_type) - return nil, nil, "gemini to gemini not yet implemented" -end - -local function from_gemini_chat_gemini(request_table, model_info, route_type) - return nil, nil, "gemini to gemini not yet implemented" -end - local transformers_to = { ["llm/v1/chat"] = to_gemini_chat_openai, } diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index fbe32f84fcb7..8e661e89317e 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -25,7 +25,6 @@ local _M = { -- static messages -local ERROR_MSG = { error = { message = "" } } local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}' @@ -497,8 +496,7 @@ function _M:access(conf) if identity_interface and identity_interface.error then kong.ctx.shared.skip_response_transformer = true kong.log.err("error authenticating with cloud-provider, ", identity_interface.error) - - return internal_server_error("LLM request failed before proxying") + return kong.response.exit(500, "LLM request failed before proxying") end -- now re-configure the request for this operation type diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 22fb1e668e34..aeb42600d639 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -660,20 +660,6 @@ describe(PLUGIN_NAME .. ": (unit)", function() }, formatted) end) - - local function dump(o) - if type(o) == 'table' then - local s = '{ ' - for k,v in pairs(o) do - if type(k) ~= 'number' then k = '"'..k..'"' end - s = s .. '['..k..'] = ' .. dump(v) .. ',' - end - return s .. '} ' - else - return tostring(o) - end - end - describe("streaming transformer tests", function() it("transforms truncated-json type (beginning of stream)", function()