From fcc3e7bb973e34d3b86c39c6007399c7faddb5ec Mon Sep 17 00:00:00 2001 From: owl Date: Wed, 14 Aug 2024 15:35:40 +0800 Subject: [PATCH] feat: fix code --- kong/llm/drivers/shared.lua | 4 +- .../02-openai_integration_spec.lua | 355 ++++++++++++++++++ 2 files changed, 358 insertions(+), 1 deletion(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b4fa0c8542464..02ee704bd67ca 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -528,7 +528,9 @@ function _M.pre_request(conf, request_table) local auth_param_location = conf.auth and conf.auth.param_location if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then - request_table[auth_param_name] = auth_param_value + if request_table[auth_param_name] == nil or not conf.auth.allow_auth_override then + request_table[auth_param_name] = auth_param_value + end end -- retrieve the plugin name diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index b1b772bfbedae..e716a5f0e3866 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -268,6 +268,41 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then path = FILE_LOG_PATH_STATS_ONLY, }, } + + -- 200 chat good with one option + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/good-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = false, + log_statistics = true, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + input_cost = 10.0, + output_cost = 10.0, + }, + }, + }, + } -- -- 200 chat good with statistics disabled @@ -436,6 +471,37 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + -- 200 completions good using query param key with no allow override + local completions_good_one_query_param_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/completions/query-param-auth-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = completions_good_one_query_param_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + param_name = "apikey", + param_value = "openai-key", + param_location = "query", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/completions/good" + }, + }, + }, + } + -- -- 200 embeddings (preserve route mode) good @@ -534,6 +600,37 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + -- 200 completions good using post body key + local completions_good_post_body_key_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/completions/post-body-auth-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = completions_good_post_body_key_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + param_name = "apikey", + param_value = "openai-key", + param_location = "body", + allow_auth_override = false, + }, + model = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/completions/good" + }, + }, + }, + } + -- -- 401 unauthorized @@ -838,6 +935,64 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.code, "invalid_api_key") end) + + it("unauthorized request with client header auth", function() + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("authorized request with client header auth", function() + local r = client:get("/openai/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer openai-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + + it("authorized request with client right header auth with no allow_auth_override", function() + local r = client:get("/openai/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer openai-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + + it("authorized request with wrong client header auth with no allow_auth_override", function() + local r = client:get("/openai/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(200 , r) + end) + end) describe("openai llm/v1/chat", function() @@ -1013,6 +1168,89 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) end) + it("works with query param auth with client right auth parm", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth?apikey=openai-key", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with query param auth with client wrong auth parm", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth?apikey=wrong", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("works with query param auth with client right auth parm with no allow-override", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth-no-allow-override?apikey=openai-key", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with query param auth with client wrong auth parm with no allow-override", function() + local r = client:get("/openai/llm/v1/completions/query-param-auth-no-allow-override?apikey=wrong", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + it("works with post body auth", function() local r = client:get("/openai/llm/v1/completions/post-body-auth", { headers = { @@ -1034,6 +1272,123 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_table(json.choices[1]) assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) end) + + it("works with post body auth", function() + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client right auth body", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "openai-key" + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client wrong auth body", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "wrong" + local r = client:get("/openai/llm/v1/completions/post-body-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + local body = assert.res_status(401 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.is_truthy(json.error) + assert.equals(json.error.code, "invalid_api_key") + end) + + it("works with post body auth with client right auth body and no allow_auth_override", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "openai-key" + local r = client:get("/openai/llm/v1/completions/post-body-auth-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) + + it("works with post body auth with client wrong auth body and no allow_auth_override", function() + local good_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-completions/requests/good.json") + local body = cjson.decode(good_body) + body.apikey = "wrong" + local r = client:get("/openai/llm/v1/completions/post-body-auth-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = cjson.encode(body), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals("cmpl-8TBeaJVQIhE9kHEJbk1RnKzgFxIqN", json.id) + assert.equals("gpt-3.5-turbo-instruct", json.model) + assert.equals("text_completion", json.object) + assert.is_table(json.choices) + assert.is_table(json.choices[1]) + assert.same("\n\nI am a language model AI created by OpenAI. I can answer questions", json.choices[1].text) + end) end) describe("one-shot request", function()