From aa5797941e6078300f9b2b408b350b911a3bf1ca Mon Sep 17 00:00:00 2001
From: owl
Date: Wed, 14 Aug 2024 16:48:02 +0800
Subject: [PATCH] feat(ai-proxy): add allow_auth_override tests for Mistral and Llama2

---
 .../06-mistral_integration_spec.lua | 135 ++++++++++++++++++
 .../07-llama2_integration_spec.lua  |  94 ++++++++++++
 2 files changed, 229 insertions(+)

diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua
index 26bc21acf999..7134fd21a54a 100644
--- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua
@@ -111,6 +111,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
          },
        },
      }

      -- route/plugin pair with allow_auth_override disabled, so the client
      -- cannot replace the configured upstream credentials
      local chat_good_no_allow_override = assert(bp.routes:insert {
        service = empty_service,
        protocols = { "http" },
        strip_path = true,
        paths = { "/mistral/llm/v1/chat/good-no-allow-override" }
      })
      bp.plugins:insert {
        name = PLUGIN_NAME,
        route = { id = chat_good_no_allow_override.id },
        config = {
          route_type = "llm/v1/chat",
          auth = {
            header_name = "Authorization",
            header_value = "Bearer mistral-key",
            allow_auth_override = false,
          },
          model = {
            name = "mistralai/Mistral-7B-Instruct-v0.1-instruct",
            provider = "mistral",
            options = {
              max_tokens = 256,
              temperature = 1.0,
              mistral_format = "openai",
              upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/v1/chat/completions",
            },
          },
        },
      }

      --
      -- 200 chat bad upstream response with one option
@@ -347,6 +376,112 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
            role = "assistant",
          }, json.choices[1].message)
        end)

        it("good request with client right auth", function()
          local r = client:get("/mistral/llm/v1/chat/good", {
            headers = {
              ["content-type"] = "application/json",
              ["accept"] = "application/json",
              ["Authorization"] = "Bearer mistral-key",
            },
            body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
          })

          -- validate that the request succeeded, response status 200
          local body = assert.res_status(200, r)
          local json = cjson.decode(body)

          -- check this is in the 'kong' response format
          assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
          assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct")
          assert.equals(json.object, "chat.completion")
          assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct")

          assert.is_table(json.choices)
          assert.is_table(json.choices[1].message)
          assert.same({
            content = "The sum of 1 + 1 is 2.",
            role = "assistant",
          }, json.choices[1].message)
        end)

        it("good request with client wrong auth", function()
          local r = client:get("/mistral/llm/v1/chat/good", {
            headers = {
              ["content-type"] = "application/json",
              ["accept"] = "application/json",
              ["Authorization"] = "Bearer wrong",
            },
            body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
          })

          -- validate that the request failed: the wrong client key overrides the
          -- configured one and is rejected upstream, response status 401
          local body = assert.res_status(401, r)
          local json = cjson.decode(body)

          -- check this is in the 'kong' response format
          assert.is_truthy(json.error)
          assert.equals(json.error.code, "invalid_api_key")
        end)

        it("good request with client right auth and no allow_auth_override", function()
          local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", {
            headers = {
              ["content-type"] = "application/json",
              ["accept"] = "application/json",
"application/json", + ["Authorization"] = "Bearer mistral-key", + + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) + + it("good request with client wrong auth and no allow_auth_override", function() + local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200 , r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") + assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct") + assert.equals(json.object, "chat.completion") + assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct") + + assert.is_table(json.choices) + assert.is_table(json.choices[1].message) + assert.same({ + content = "The sum of 1 + 1 is 2.", + role = "assistant", + }, json.choices[1].message) + end) end) describe("mistral llm/v1/completions", function() diff --git a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua index 778804d4af6c..0060ddaf4fb2 100644 --- a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua @@ -141,6 +141,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }, } + + local chat_good_no_allow_override = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/raw/llm/v1/completions-no-allow-override" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_no_allow_override.id }, + config = { + route_type = "llm/v1/completions", + auth = { + header_name = "Authorization", + header_value = "Bearer llama2-key", + allow_auth_override = false, + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/raw/llm/v1/completions", + }, + }, + }, + } -- -- start kong @@ -198,6 +227,71 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.equals(json.choices[1].text, "Is a well known font.") end) + + it("runs good request in completions format with client right auth", function() + local r = client:get("/raw/llm/v1/completions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer llama2-key" + }, + body = 
pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + + it("runs good request in completions format with client wrong auth", function() + local r = client:get("/raw/llm/v1/completions", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(401, r) + local json = cjson.decode(body) + + assert.equals(json.error, "Model requires a Pro subscription.") + end) + + it("runs good request in completions format with client right auth and no allow_auth_override", function() + local r = client:get("/raw/llm/v1/completions-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer llama2-key" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + + it("runs good request in completions format with client wrong auth and no allow_auth_override", function() + local r = client:get("/raw/llm/v1/completions-no-allow-override", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + ["Authorization"] = "Bearer wrong" + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"), + }) + + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + assert.equals(json.choices[1].text, "Is a well known font.") + end) + end) describe("one-shot request", function()