From 700b3b07c2124b759fd403c509d58bf5c4da4c7a Mon Sep 17 00:00:00 2001 From: Jun Ouyang Date: Thu, 15 Aug 2024 11:06:29 +0800 Subject: [PATCH] feat(llm): add auth.allow_override option for llm auth functionality (#13493) Add `allow_override` option to allow overriding the upstream model auth parameter or header from the caller's request. --- .../kong/ai-proxy-add-allow-override-opt.yml | 4 + kong/clustering/compat/removed_fields.lua | 1 + kong/llm/drivers/anthropic.lua | 11 +- kong/llm/drivers/azure.lua | 10 +- kong/llm/drivers/cohere.lua | 11 +- kong/llm/drivers/llama2.lua | 11 +- kong/llm/drivers/mistral.lua | 11 +- kong/llm/drivers/openai.lua | 11 +- kong/llm/drivers/shared.lua | 4 +- kong/llm/schemas/init.lua | 10 + .../09-hybrid_mode/09-config-compat_spec.lua | 5 + .../03-plugins/38-ai-proxy/00-config_spec.lua | 55 +++ .../02-openai_integration_spec.lua | 355 ++++++++++++++++++ .../03-anthropic_integration_spec.lua | 128 +++++++ .../04-cohere_integration_spec.lua | 123 ++++++ .../38-ai-proxy/05-azure_integration_spec.lua | 129 +++++++ .../06-mistral_integration_spec.lua | 135 +++++++ .../07-llama2_integration_spec.lua | 94 +++++ 18 files changed, 1090 insertions(+), 18 deletions(-) create mode 100644 changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml diff --git a/changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml b/changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml new file mode 100644 index 000000000000..798dffc5e59a --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml @@ -0,0 +1,4 @@ +message: | + **AI-proxy-plugin**: Add `allow_auth_override` option to allow overriding the upstream model auth parameter or header from the caller's request. +scope: Plugin +type: feature diff --git a/kong/clustering/compat/removed_fields.lua b/kong/clustering/compat/removed_fields.lua index d37db9a4172d..3493266c6221 100644 --- a/kong/clustering/compat/removed_fields.lua +++ b/kong/clustering/compat/removed_fields.lua @@ -175,6 +175,7 @@ return { "model.options.bedrock", "auth.aws_access_key_id", "auth.aws_secret_access_key", + "auth.allow_auth_override", "model_name_header", }, ai_prompt_decorator = { diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 332f21878098..c942dbcbbe44 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -472,13 +472,18 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end if auth_param_name and auth_param_value and auth_param_location == "query" then local query_table = kong.request.get_query() - query_table[auth_param_name] = auth_param_value - kong.service.request.set_query(query_table) + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end end -- if auth_param_location is "form", it will have already been set in a pre-request hook diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 343904ffad24..b88bffbfd1d1 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -131,9 +131,13 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end + local query_table = kong.request.get_query() -- technically min supported version @@ -141,7 +145,9 @@ function _M.configure_request(conf) or (conf.model.options and conf.model.options.azure_api_version) if auth_param_name and auth_param_value and auth_param_location == "query" then - query_table[auth_param_name] = auth_param_value + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + end end kong.service.request.set_query(query_table) diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index d25764164086..89151608caac 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -480,13 +480,18 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end if auth_param_name and auth_param_value and auth_param_location == "query" then local query_table = kong.request.get_query() - query_table[auth_param_name] = auth_param_value - kong.service.request.set_query(query_table) + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end end -- if auth_param_location is "form", it will have already been set in a pre-request hook diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 0526453f8a52..446e7295e70b 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -277,13 +277,18 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end if auth_param_name and auth_param_value and auth_param_location == "query" then local query_table = kong.request.get_query() - query_table[auth_param_name] = auth_param_value - kong.service.request.set_query(query_table) + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end end -- if auth_param_location is "form", it will have already been set in a pre-request hook diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index d1d2303b6919..8ae85b3a513b 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -172,13 +172,18 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end if auth_param_name and auth_param_value and auth_param_location == "query" then local query_table = kong.request.get_query() - query_table[auth_param_name] = auth_param_value - kong.service.request.set_query(query_table) + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end end -- if auth_param_location is "form", it will have already been set in a pre-request hook diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 52df1910586e..b77cd1aafc37 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -213,13 +213,18 @@ function _M.configure_request(conf) local auth_param_location = conf.auth and conf.auth.param_location if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + local exist_value = kong.request.get_header(auth_header_name) + if exist_value == nil or not conf.auth.allow_auth_override then + kong.service.request.set_header(auth_header_name, auth_header_value) + end end if auth_param_name and auth_param_value and auth_param_location == "query" then local query_table = kong.request.get_query() - query_table[auth_param_name] = auth_param_value - kong.service.request.set_query(query_table) + if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end end -- if auth_param_location is "form", it will have already been set in a global pre-request hook diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b4fa0c854246..02ee704bd67c 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -528,7 +528,9 @@ function _M.pre_request(conf, request_table) local auth_param_location = conf.auth and conf.auth.param_location if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then - request_table[auth_param_name] = auth_param_value + if request_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + request_table[auth_param_name] = auth_param_value + end end -- retrieve the plugin name diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index 0fcc3a058a31..c4cf0e302baf 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -97,6 +97,11 @@ local auth_schema = { required = false, encrypted = true, referenceable = true }}, + { allow_auth_override = { + type = "boolean", + description = "If enabled, the authorization header or parameter can be overridden in the request by the value configured in the plugin.", + required = false, + default = true }}, } } @@ -237,6 +242,11 @@ return { { logging = logging_schema }, }, entity_checks = { + { conditional = { if_field = "model.provider", + if_match = { one_of = { "bedrock", "gemini" } }, + then_field = "auth.allow_auth_override", + then_match = { eq = false }, + then_err = "bedrock and gemini only support auth.allow_auth_override = false" }}, { mutually_required = { "auth.header_name", "auth.header_value" }, }, { mutually_required = { "auth.param_name", "auth.param_value", "auth.param_location" }, }, diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index 808f4cd5ade3..e75628a8cb00 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -495,6 +495,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() header_value = "value", gcp_service_account_json = '{"service": "account"}', gcp_use_service_account = true, + allow_auth_override = false, }, model = { name = "any-model-name", @@ -526,6 +527,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() -- gemini fields expected.config.auth.gcp_service_account_json = nil expected.config.auth.gcp_use_service_account = nil + expected.config.auth.allow_auth_override = nil expected.config.model.options.gemini = nil -- bedrock fields @@ -562,6 +564,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() header_value = "value", gcp_service_account_json = '{"service": "account"}', gcp_use_service_account = true, + allow_auth_override = false, }, model = { name = "any-model-name", @@ -625,6 +628,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() header_value = "value", gcp_service_account_json = '{"service": "account"}', gcp_use_service_account = true, + allow_auth_override = false, }, model = { name = "any-model-name", @@ -720,6 +724,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() -- bedrock fields expected.config.auth.aws_access_key_id = nil expected.config.auth.aws_secret_access_key = nil + expected.config.auth.allow_auth_override = nil expected.config.model.options.bedrock = nil do_assert(uuid(), "3.7.0", expected) diff --git a/spec/03-plugins/38-ai-proxy/00-config_spec.lua b/spec/03-plugins/38-ai-proxy/00-config_spec.lua index 34cddb74d619..3aa61ef5d460 100644 --- a/spec/03-plugins/38-ai-proxy/00-config_spec.lua +++ b/spec/03-plugins/38-ai-proxy/00-config_spec.lua @@ -289,4 +289,59 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.is_truthy(ok) end) + it("bedrock model can not support ath.allowed_auth_override", function() + local config = { + route_type = "llm/v1/chat", + auth = { + param_name = "apikey", + param_value = "key", + param_location = "query", + header_name = "Authorization", + header_value = "Bearer token", + allow_auth_override = true, + }, + model = { + name = "bedrock", + provider = "bedrock", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://nowhere", + }, + }, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.is_truthy(err) + end) + + it("gemini model can not support ath.allowed_auth_override", function() + local config = { + route_type = "llm/v1/chat", + auth = { + param_name = "apikey", + param_value = "key", + param_location = "query", + header_name = "Authorization", + header_value = "Bearer token", + allow_auth_override = true, + }, + model = { + name = "gemini", + provider = "gemini", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://nowhere", + }, + }, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.is_truthy(err) + end) end) diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index b1b772bfbeda..e716a5f0e386 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -268,6 +268,41 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then path = FILE_LOG_PATH_STATS_ONLY, }, } + + -- 200 chat good with one option + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = false, + log_statistics = true, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + input_cost = 10.0, + output_cost = 10.0, + }, + }, + }, + } -- -- 200 chat good with statistics disabled @@ -436,6 +471,37 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + -- 200 completions good using query param key with no allow override + local completions_good_one_query_param_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/completions/query-param-auth-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = completions_good_one_query_param_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + param_name = "apikey", + param_value = "openai-key", + param_location = "query", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/completions/good" + }, + }, + }, + } + -- -- 200 embeddings (preserve route mode) good @@ -534,6 +600,37 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + -- 200 completions good using post body key + local completions_good_post_body_key_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/completions/post-body-auth-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = completions_good_post_body_key_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + param_name = "apikey", + param_value = "openai-key", + param_location = "body", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/completions/good" + }, + }, + }, + } + -- -- 401 unauthorized @@ -838,6 +935,64 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.code, "invalid_api_key") end) + + it("unauthorized request with client header auth", function() + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("authorized request with client header auth", function() + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer openai-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + + it("authorized request with client right header auth with no allow_auth_override", function() + local r = client:get("/openai/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer openai-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + + it("authorized request with wrong client header auth with no allow_auth_override", function() + local r = client:get("/openai/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + end) describe("openai llm/v1/chat", function() @@ -1013,6 +1168,89 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) end) + it("works with query param auth with client right auth parm", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth?apikey=openai-key", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with query param auth with client wrong auth parm", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth?apikey=wrong", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("works with query param auth with client right auth parm with no allow-override", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth-no-allow-override?apikey=openai-key", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with query param auth with client wrong auth parm with no allow-override", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth-no-allow-override?apikey=wrong", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + it("works with post body auth", function() local r = client:get("/openai/llm/v1/completions/post-body-auth", { headers = { @@ -1034,6 +1272,123 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_table(json.choices[1]) assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) end) + + it("works with post body auth", function() + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client right auth body", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "openai-key" + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client wrong auth body", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "wrong" + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("works with post body auth with client right auth body and no allow_auth_override", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "openai-key" + local r = client:get("/openai/llm/v1/completions/post-body-auth-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client wrong auth body and no allow_auth_override", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "wrong" + local r = client:get("/openai/llm/v1/completions/post-body-auth-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) end) describe("one-shot request", function() diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index dd52f7e066a9..6d87425054e8 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -218,6 +218,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/anthropic/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "x-api-key", + header_value = "anthropic-key", + allow_auth_override = false, + }, + model = { + name = "claude-2.1", + provider = "anthropic", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + anthropic_version = "2023-06-01", + }, + }, + }, + } -- -- 200 chat bad upstream response with one option @@ -551,6 +580,105 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, json.choices[1].message) end) + it("good request with client right header auth", function() + local r = client:get("/anthropic/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["x-api-key"] = "anthropic-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "claude-2.1") + assert.equals(json.object, "chat.content") + assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong header auth", function() + local r = client:get("/anthropic/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["x-api-key"] = "wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.type, "authentication_error") + end) + + it("good request with client right header auth and no allow_auth_override", function() + local r = client:get("/anthropic/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["x-api-key"] = "anthropic-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "claude-2.1") + assert.equals(json.object, "chat.content") + assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong header auth and no allow_auth_override", function() + local r = client:get("/anthropic/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["x-api-key"] = "wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "claude-2.1") + assert.equals(json.object, "chat.content") + assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + it("bad upstream response", function() local r = client:get("/anthropic/llm/v1/chat/bad_upstream_response", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua index eb52249c8fb4..d3d0f55a9ce9 100644 --- a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua @@ -166,6 +166,33 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/cohere/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer cohere-key", + allow_auth_override = false, + }, + model = { + name = "command", + provider = "cohere", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + }, + }, + }, + } -- -- 200 chat bad upstream response with one option @@ -426,6 +453,102 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, json.choices[1].message) end) + it("good request with right client auth", function() + local r = client:get("/cohere/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer cohere-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.model, "command") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "cohere/command") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with wrong client auth", function() + local r = client:get("/cohere/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.message) + assert.equals(json.message, "invalid api token") + end) + + it("good request with right client auth and no allow_auth_override", function() + local r = client:get("/cohere/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer cohere-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.model, "command") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "cohere/command") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with wrong client auth and no allow_auth_override", function() + local r = client:get("/cohere/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.model, "command") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "cohere/command") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + it("bad upstream response", function() local r = client:get("/cohere/llm/v1/chat/bad_upstream_response", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua index 82385720efcf..baa6a618389d 100644 --- a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua @@ -168,6 +168,36 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/azure/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + azure_instance = "001-kong-t", + azure_deployment_id = "gpt-3.5-custom", + }, + }, + }, + } -- -- 200 chat bad upstream response with one option @@ -442,6 +472,105 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, json.choices[1].message) end) + it("good request with client right auth", function() + local r = client:get("/azure/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["api-key"] = "azure-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "gpt-3.5-turbo-0613") + assert.equals(json.object, "chat.completion") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong auth", function() + local r = client:get("/azure/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["api-key"] = "wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("good request with client right auth and no allow_auth_override", function() + local r = client:get("/azure/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["api-key"] = "azure-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "gpt-3.5-turbo-0613") + assert.equals(json.object, "chat.completion") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong auth and no allow_auth_override", function() + local r = client:get("/azure/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["api-key"] = "wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "gpt-3.5-turbo-0613") + assert.equals(json.object, "chat.completion") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + it("bad upstream response", function() local r = client:get("/azure/llm/v1/chat/bad_upstream_response", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua index 26bc21acf999..7134fd21a54a 100644 --- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua @@ -111,6 +111,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/mistral/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer mistral-key", + allow_auth_override = false, + }, + model = { + name = "mistralai/Mistral-7B-Instruct-v0.1-instruct", + provider = "mistral", + options = { + max_tokens = 256, + temperature = 1.0, + mistral_format = "openai", + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/v1/chat/completions", + }, + }, + }, + } -- -- 200 chat bad upstream response with one option @@ -347,6 +376,112 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then role = "assistant", }, json.choices[1].message) end) + + it("good request with client right auth", function() + local r = client:get("/mistral/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer mistral-key", + + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong auth", function() + local r = client:get("/mistral/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("good request with client right auth and no allow_auth_override", function() + local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer mistral-key", + + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong auth and no allow_auth_override", function() + local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) end) describe("mistral llm/v1/completions", function() diff --git a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua index 778804d4af6c..0060ddaf4fb2 100644 --- a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua @@ -141,6 +141,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/raw/llm/v1/completions-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + header_name = "Authorization", + header_value = "Bearer llama2-key", + allow_auth_override = false, + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/raw/llm/v1/completions", + }, + }, + }, + } -- -- start kong @@ -198,6 +227,71 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.equals(json.choices[1].text, "Is a well known font.") end) + + it("runs good request in completions format with client right auth", function() + local r = client:get("/raw/llm/v1/completions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer llama2-key" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + + it("runs good request in completions format with client wrong auth", function() + local r = client:get("/raw/llm/v1/completions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(401, r) + local json = cjson.decode(body) + + assert.equals(json.error, "Model requires a Pro subscription.") + end) + + it("runs good request in completions format with client right auth and no allow_auth_override", function() + local r = client:get("/raw/llm/v1/completions-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer llama2-key" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + + it("runs good request in completions format with client wrong auth and no allow_auth_override", function() + local r = client:get("/raw/llm/v1/completions-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + end) describe("one-shot request", function()