diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 77c9f363f9b6..7161804cd935 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -71,7 +71,7 @@ local function to_claude_prompt(req) return kong_messages_to_claude_prompt(req.messages) end - + return nil, "request is missing .prompt and .messages commands" end @@ -328,7 +328,7 @@ function _M.from_format(response_string, model_info, route_type) if not transform then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index a0ba1741a861..343904ffad24 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -139,7 +139,7 @@ function _M.configure_request(conf) -- technically min supported version query_table["api-version"] = kong.request.get_query_arg("api-version") or (conf.model.options and conf.model.options.azure_api_version) - + if auth_param_name and auth_param_value and auth_param_location == "query" then query_table[auth_param_name] = auth_param_value end diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua index 372a57fa8276..a32ad5120e6e 100644 --- a/kong/llm/drivers/bedrock.lua +++ b/kong/llm/drivers/bedrock.lua @@ -266,7 +266,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 1aafc9405b0c..28c8c64eaac0 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -21,7 +21,7 @@ local _CHAT_ROLES = { local function handle_stream_event(event_t, model_info, route_type) local metadata - + -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then return @@ -31,12 +31,12 @@ local function handle_stream_event(event_t, model_info, route_type) if err then return nil, "failed to decode event frame from cohere: " .. err, nil end - + local new_event - + if event.event_type == "stream-start" then kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id - + -- ignore the rest of this one new_event = { choices = { @@ -52,7 +52,7 @@ local function handle_stream_event(event_t, model_info, route_type) model = model_info.name, object = "chat.completion.chunk", } - + elseif event.event_type == "text-generation" then -- this is a token if route_type == "stream/llm/v1/chat" then @@ -137,19 +137,19 @@ end local function handle_json_inference_event(request_table, model) request_table.temperature = request_table.temperature request_table.max_tokens = request_table.max_tokens - + request_table.p = request_table.top_p request_table.k = request_table.top_k - + request_table.top_p = nil request_table.top_k = nil - + request_table.model = model.name or request_table.model request_table.stream = request_table.stream or false -- explicitly set this - + if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") - + elseif request_table.messages then -- we have to move all BUT THE LAST message into "chat_history" array -- and move the LAST message (from 'user') into "message" string @@ -164,26 +164,26 @@ local function handle_json_inference_event(request_table, model) else role = _CHAT_ROLES.user end - + chat_history[i] = { role = role, message = v.content, } end end - + request_table.chat_history = chat_history end - + request_table.message = request_table.messages[#request_table.messages].content request_table.messages = nil - + elseif request_table.prompt then request_table.prompt = request_table.prompt request_table.messages = nil request_table.message = nil end - + return request_table, "application/json", nil end @@ -202,7 +202,7 @@ local transformers_from = { -- messages/choices table is only 1 size, so don't need to static allocate local messages = {} messages.choices = {} - + if response_table.prompt and response_table.generations then -- this is a "co.generate" for i, v in ipairs(response_table.generations) do @@ -215,7 +215,7 @@ local transformers_from = { messages.object = "text_completion" messages.model = model_info.name messages.id = response_table.id - + local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units @@ -230,10 +230,10 @@ local transformers_from = { and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats - + elseif response_table.text then -- this is a "co.chat" - + messages.choices[1] = { index = 0, message = { @@ -245,7 +245,7 @@ local transformers_from = { messages.object = "chat.completion" messages.model = model_info.name messages.id = response_table.generation_id - + local stats = { completion_tokens = response_table.meta and response_table.meta.billed_units @@ -260,10 +260,10 @@ local transformers_from = { and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats - + else -- probably a fault return nil, "'text' or 'generations' missing from cohere response body" - + end return cjson.encode(messages) @@ -314,17 +314,17 @@ local transformers_from = { prompt.object = "chat.completion" prompt.model = model_info.name prompt.id = response_table.generation_id - + local stats = { completion_tokens = response_table.token_count and response_table.token_count.response_tokens, prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens, total_tokens = response_table.token_count and response_table.token_count.total_tokens, } prompt.usage = stats - + else -- probably a fault return nil, "'text' or 'generations' missing from cohere response body" - + end return cjson.encode(prompt) @@ -465,7 +465,7 @@ function _M.configure_request(conf) and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path or "/" end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 57ca7127ef29..0a68a0af8e10 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -123,7 +123,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) } end end - + new_r.generationConfig = to_gemini_generation_config(request_table) return new_r, "application/json", nil @@ -222,7 +222,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 02b65818bd3a..0526453f8a52 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -113,7 +113,7 @@ local function to_raw(request_table, model) messages.parameters.top_k = request_table.top_k messages.parameters.temperature = request_table.temperature messages.parameters.stream = request_table.stream or false -- explicitly set this - + if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run raw 'prompt' and chat history 'messages' requests at the same time - refer to schema") diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index ca1e37d91831..331ad3b5e7ea 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -72,7 +72,7 @@ function _M.from_format(response_string, model_info, route_type) if not transformers_from[route_type] then return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - + local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", @@ -203,7 +203,7 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) parsed_url.path = path end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 8004b9384df1..15d9ce7e62f2 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -259,7 +259,7 @@ function _M.frame_to_events(frame, provider) -- it may start with ',' which is the start of the new frame frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame - + -- it may end with the array terminator ']' indicating the finished stream if string.sub(str_rtrim(frame), -1) == "]" then frame = string.sub(str_rtrim(frame), 1, -2) @@ -446,7 +446,7 @@ function _M.from_ollama(response_string, model_info, route_type) end end - + if output and output ~= _M._CONST.SSE_TERMINATOR then output, err = cjson.encode(output) end diff --git a/spec/02-integration/22-ai_plugins/01-reports_spec.lua b/spec/02-integration/22-ai_plugins/01-reports_spec.lua index 78c98c03153f..9c4858e7127a 100644 --- a/spec/02-integration/22-ai_plugins/01-reports_spec.lua +++ b/spec/02-integration/22-ai_plugins/01-reports_spec.lua @@ -38,32 +38,32 @@ for _, strategy in helpers.each_strategy() do local fixtures = { http_mock = {}, } - + fixtures.http_mock.openai = [[ server { server_name openai; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - - + + location = "/llm/v1/chat/good" { content_by_lua_block { local pl_file = require "pl.file" local json = require("cjson.safe") - + ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) diff --git a/spec/03-plugins/38-ai-proxy/00-config_spec.lua b/spec/03-plugins/38-ai-proxy/00-config_spec.lua index bbd495918bbb..0a15f131b46b 100644 --- a/spec/03-plugins/38-ai-proxy/00-config_spec.lua +++ b/spec/03-plugins/38-ai-proxy/00-config_spec.lua @@ -84,7 +84,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() end local ok, err = validate(config) - + assert.is_truthy(ok) assert.is_falsy(err) end) @@ -220,7 +220,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.equal(err["config"]["@entity"][1], "must set one of 'auth.header_name', 'auth.param_name', " .. "and its respective options, when provider is not self-hosted") assert.is_falsy(ok) @@ -244,7 +244,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.equals(err["config"]["@entity"][1], "all or none of these fields must be set: 'auth.header_name', 'auth.header_value'") assert.is_falsy(ok) end) @@ -268,7 +268,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) @@ -317,7 +317,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) @@ -344,7 +344,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() } local ok, err = validate(config) - + assert.is_falsy(err) assert.is_truthy(ok) end) diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 009f079195d0..a73f12a409b8 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -532,7 +532,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() local expected_request_json = pl_file.read(filename) local expected_request_table, err = cjson.decode(expected_request_json) assert.is_nil(err) - + -- compare the tables assert.same(expected_request_table, actual_request_table) end) @@ -547,7 +547,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() local filename if l.config.provider == "llama2" then filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.llama2_format, pl_replace(k, "/", "-")) - + elseif l.config.provider == "mistral" then filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.mistral_format, pl_replace(k, "/", "-")) @@ -604,7 +604,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("converts to provider request format correctly", function() -- load the real provider frame from file local real_stream_frame = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/real-stream-frames/%s/%s.txt", config.provider, pl_replace(format_name, "/", "-"))) - + -- use the shared function to produce an SSE format object local real_transformed_frame, err = ai_shared.frame_to_events(real_stream_frame) assert.is_nil(err) @@ -628,7 +628,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- generic tests it("throws correct error when format is not supported", function() local driver = require("kong.llm.drivers.mistral") -- one-shot, random example of provider with only prompt support - + local model_config = { route_type = "llm/v1/chatnopenotsupported", name = "mistral-tiny", @@ -651,7 +651,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.equal(err, "no transformer available to format mistral://llm/v1/chatnopenotsupported/ollama") end) - + it("produces a correct default config merge", function() local formatted, err = ai_shared.merge_config_defaults( SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS, @@ -675,7 +675,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) describe("streaming transformer tests", function() - + it("transforms truncated-json type (beginning of stream)", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin")) local events = ai_shared.frame_to_events(input, "gemini") @@ -695,7 +695,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same(events, expected_events, true) end) - + it("transforms complete-json type", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) local events = ai_shared.frame_to_events(input, "cohere") -- not "truncated json mode" like Gemini diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index 0b5468e2e88d..e963d908e324 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -68,14 +68,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.openai = [[ server { server_name openai; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -93,7 +93,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -118,7 +118,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -136,7 +136,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -145,7 +145,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -166,7 +166,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local token_query = ngx.req.get_uri_args()["apikey"] if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -184,7 +184,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) } @@ -664,7 +664,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -692,7 +692,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -739,7 +739,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -817,7 +817,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -830,7 +830,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -849,7 +849,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -937,7 +937,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -956,7 +956,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -979,7 +979,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -1038,7 +1038,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe("one-shot request", function() it("success", function() local ai_driver = require("kong.llm.drivers.openai") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -1054,7 +1054,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - + local request = { messages = { [1] = { @@ -1067,15 +1067,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } } - + -- convert it to the specified driver format local ai_request = ai_driver.to_format(request, plugin_conf.model, "llm/v1/chat") - + -- send it to the ai service local ai_response, status_code, err = ai_driver.subrequest(ai_request, plugin_conf, {}, false) assert.is_nil(err) assert.equal(200, status_code) - + -- parse and convert the response local ai_response, _, err = ai_driver.from_format(ai_response, plugin_conf.model, plugin_conf.route_type) assert.is_nil(err) @@ -1092,7 +1092,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("404", function() local ai_driver = require("kong.llm.drivers.openai") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -1108,7 +1108,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - + local request = { messages = { [1] = { @@ -1121,7 +1121,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } } - + -- convert it to the specified driver format local ai_request = ai_driver.to_format(request, plugin_conf.model, "llm/v1/chat") diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index c3cdc525c61e..78f990fe6161 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -17,14 +17,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.anthropic = [[ server { server_name anthropic; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -36,7 +36,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.messages) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) @@ -61,7 +61,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.messages) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) @@ -129,7 +129,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) } @@ -138,7 +138,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/internal_server_error.html")) @@ -156,7 +156,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.prompt) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) @@ -174,7 +174,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) } @@ -501,7 +501,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -514,7 +514,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -558,7 +558,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -573,7 +573,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -619,7 +619,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -640,7 +640,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua index 721cf97566e7..548db5e59be1 100644 --- a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua @@ -16,14 +16,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.cohere = [[ server { server_name cohere; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -78,7 +78,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/responses/bad_request.json")) } @@ -87,7 +87,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/responses/internal_server_error.html")) @@ -105,7 +105,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (not body.prompt) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/responses/bad_request.json")) @@ -123,7 +123,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/responses/bad_request.json")) } @@ -356,7 +356,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -378,7 +378,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -391,7 +391,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -409,7 +409,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -433,7 +433,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -448,7 +448,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -466,7 +466,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/requests/good.json"), }) - + local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -487,7 +487,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua index a8efe9b21a1e..d76d0c4ac50d 100644 --- a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua @@ -16,14 +16,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.azure = [[ server { server_name azure; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; - + location = "/llm/v1/chat/good" { content_by_lua_block { @@ -35,7 +35,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -60,7 +60,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -78,7 +78,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -87,7 +87,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/chat/internal_server_error" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -105,7 +105,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -123,7 +123,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/llm/v1/completions/bad_request" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) } @@ -370,7 +370,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -392,7 +392,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(500 , r) assert.is_not_nil(body) end) @@ -405,7 +405,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + local body = assert.res_status(401 , r) local json = cjson.decode(body) @@ -424,7 +424,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -450,7 +450,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) @@ -465,7 +465,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) @@ -484,7 +484,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -507,7 +507,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/bad_request.json"), }) - + local body = assert.res_status(400 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua index 3c711cd83b44..94058750ff1d 100644 --- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua @@ -16,12 +16,12 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.mistral = [[ server { server_name mistral; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location = "/v1/chat/completions" { @@ -34,7 +34,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -59,7 +59,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.prompt == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/responses/bad_request.json")) @@ -307,7 +307,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -329,7 +329,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) @@ -357,7 +357,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), }) - + -- validate that the request succeeded, response status 200 local body = assert.res_status(200 , r) local json = cjson.decode(body) diff --git a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua index aa74ef9fd5ba..778804d4af6c 100644 --- a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua @@ -16,12 +16,12 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local fixtures = { http_mock = {}, } - + fixtures.http_mock.llama2 = [[ server { server_name llama2; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location = "/raw/llm/v1/chat" { @@ -155,7 +155,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -177,7 +177,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-chat.json"), }) - + local body = assert.res_status(200, r) local json = cjson.decode(body) @@ -192,7 +192,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), }) - + local body = assert.res_status(200, r) local json = cjson.decode(body) @@ -203,7 +203,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe("one-shot request", function() it("success", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -220,7 +220,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } - + local request = { messages = { [1] = { @@ -260,7 +260,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("404", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { @@ -303,7 +303,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then it("401", function() local ai_driver = require("kong.llm.drivers.llama2") - + local plugin_conf = { route_type = "llm/v1/chat", auth = { diff --git a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua index 0cc63ba41ba4..0e9801a2923e 100644 --- a/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/08-encoding_integration_spec.lua @@ -152,7 +152,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() @@ -235,7 +235,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index 0d78e57b7786..24707b8039eb 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -494,7 +494,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -691,7 +691,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then end end until not buffer - + assert.equal(#events, 17) assert.equal(buf:tostring(), "1 + 1 = 2. This is the most basic example of addition.") end) @@ -753,7 +753,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then end end until not buffer - + assert.equal(#events, 8) assert.equal(buf:tostring(), "1 + 1 = 2") end) diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua index 0b051da1479b..9598bab7f56f 100644 --- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua @@ -158,7 +158,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then server { server_name llm; listen ]]..MOCK_PORT..[[; - + default_type 'application/json'; location ~/flat { @@ -171,7 +171,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/badrequest" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) } @@ -180,7 +180,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then location = "/internalservererror" { content_by_lua_block { local pl_file = require "pl.file" - + ngx.status = 500 ngx.header["content-type"] = "text/html" ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/internal_server_error.html")) @@ -248,7 +248,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, }, nil, nil, fixtures)) end) - + lazy_teardown(function() helpers.stop_kong() end) @@ -270,7 +270,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = REQUEST_BODY, }) - + local body = assert.res_status(200 , r) local body_table, err = cjson.decode(body) diff --git a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua index eab961081e65..b2c11519b056 100644 --- a/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua @@ -18,7 +18,7 @@ local function create_request(typ) if typ ~= "chat" and typ ~= "completions" then error("type must be one of 'chat' or 'completions'", 2) end - + return setmetatable({ messages = messages, type = typ,