Skip to content

Commit

Permalink
feat: fix code
Browse files Browse the repository at this point in the history
  • Loading branch information
oowl committed Aug 14, 2024
1 parent 1ad5b8a commit f81c6f5
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 0 deletions.
135 changes: 135 additions & 0 deletions spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
},
}

local chat_good_no_allow_override = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/mistral/llm/v1/chat/good-no-allow-override" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_good_no_allow_override.id },
config = {
route_type = "llm/v1/chat",
auth = {
header_name = "Authorization",
header_value = "Bearer mistral-key",
allow_auth_override = false,
},
model = {
name = "mistralai/Mistral-7B-Instruct-v0.1-instruct",
provider = "mistral",
options = {
max_tokens = 256,
temperature = 1.0,
mistral_format = "openai",
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/v1/chat/completions",
},
},
},
}
--

-- 200 chat bad upstream response with one option
Expand Down Expand Up @@ -347,6 +376,112 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
role = "assistant",
}, json.choices[1].message)
end)

it("good request with client right auth", function()
local r = client:get("/mistral/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer mistral-key",

},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)
end)

it("good request with client wrong auth", function()
local r = client:get("/mistral/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer wrong",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
})

-- validate that the request succeeded, response status 200

local body = assert.res_status(401 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.is_truthy(json.error)
assert.equals(json.error.code, "invalid_api_key")
end)

it("good request with client right auth and no allow_auth_override", function()
local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer mistral-key",

},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)
end)

it("good request with client wrong auth and no allow_auth_override", function()
local r = client:get("/mistral/llm/v1/chat/good-no-allow-override", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer wrong",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "mistralai/Mistral-7B-Instruct-v0.1-instruct")
assert.equals(json.object, "chat.completion")
assert.equals(r.headers["X-Kong-LLM-Model"], "mistral/mistralai/Mistral-7B-Instruct-v0.1-instruct")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)
end)
end)

describe("mistral llm/v1/completions", function()
Expand Down
94 changes: 94 additions & 0 deletions spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,35 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
},
},
}

local chat_good_no_allow_override = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/raw/llm/v1/completions-no-allow-override" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_good_no_allow_override.id },
config = {
route_type = "llm/v1/completions",
auth = {
header_name = "Authorization",
header_value = "Bearer llama2-key",
allow_auth_override = false,
},
model = {
name = "llama-2-7b-chat-hf",
provider = "llama2",
options = {
max_tokens = 256,
temperature = 1.0,
llama2_format = "raw",
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/raw/llm/v1/completions",
},
},
},
}
--

-- start kong
Expand Down Expand Up @@ -198,6 +227,71 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then

assert.equals(json.choices[1].text, "Is a well known font.")
end)

it("runs good request in completions format with client right auth", function()
local r = client:get("/raw/llm/v1/completions", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer llama2-key"
},
body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"),
})

local body = assert.res_status(200, r)
local json = cjson.decode(body)

assert.equals(json.choices[1].text, "Is a well known font.")
end)

it("runs good request in completions format with client wrong auth", function()
local r = client:get("/raw/llm/v1/completions", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer wrong"
},
body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"),
})

local body = assert.res_status(401, r)
local json = cjson.decode(body)

assert.equals(json.error, "Model requires a Pro subscription.")
end)

it("runs good request in completions format with client right auth and no allow_auth_override", function()
local r = client:get("/raw/llm/v1/completions-no-allow-override", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer llama2-key"
},
body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"),
})

local body = assert.res_status(200, r)
local json = cjson.decode(body)

assert.equals(json.choices[1].text, "Is a well known font.")
end)

it("runs good request in completions format with client wrong auth and no allow_auth_override", function()
local r = client:get("/raw/llm/v1/completions-no-allow-override", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
["Authorization"] = "Bearer wrong"
},
body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-completions.json"),
})

local body = assert.res_status(200, r)
local json = cjson.decode(body)

assert.equals(json.choices[1].text, "Is a well known font.")
end)

end)

describe("one-shot request", function()
Expand Down

0 comments on commit f81c6f5

Please sign in to comment.