-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(plugins): ai-prompt-guard-plugin (#12337)
* feat(plugins): ai-prompt-guard-plugin * fix(ai-prompt-guard): fixes from code review * Update kong/plugins/ai-prompt-guard/schema.lua Co-authored-by: Vinicius Mignot <[email protected]> --------- Co-authored-by: Jack Tysoe <[email protected]> Co-authored-by: Vinicius Mignot <[email protected]>
- Loading branch information
1 parent
3ef9235
commit 93a1887
Showing
10 changed files
with
897 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
message: Introduced the new **AI Prompt Guard** which can allow and/or block LLM requests based on pattern matching. | ||
type: feature | ||
scope: Plugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
local _M = {} | ||
|
||
-- imports | ||
local kong_meta = require "kong.meta" | ||
local buffer = require("string.buffer") | ||
local ngx_re_find = ngx.re.find | ||
-- | ||
|
||
_M.PRIORITY = 771 | ||
_M.VERSION = kong_meta.version | ||
|
||
--- Terminate the current request with an HTTP 400 response.
-- The full reason is always written to the info log. The client only sees it
-- when `reveal_msg_to_client` is truthy; otherwise a generic message is sent
-- so that callers cannot tell 'ai-prompt-guard' is in use.
-- @param msg                  reason for the rejection
-- @param reveal_msg_to_client when truthy, `msg` is echoed back to the client
-- @return whatever `kong.response.exit` returns
local function bad_request(msg, reveal_msg_to_client)
  kong.log.info(msg)

  local client_msg = "bad request"
  if reveal_msg_to_client then
    client_msg = msg
  end

  return kong.response.exit(400, { error = { message = client_msg } })
end
|
||
--- Test a prompt against a list of regex patterns.
-- @param prompt   string to search
-- @param patterns array of regex strings
-- @return true if any pattern matches, false if none does;
--         nil plus the offending pattern if a regex failed to execute
local function matches_any(prompt, patterns)
  for _, pattern in ipairs(patterns) do
    local from, _, err = ngx_re_find(prompt, pattern, "jo")
    if err then
      return nil, pattern
    end

    if from then
      return true
    end
  end

  return false
end


--- Check an LLM request body against the configured allow/deny patterns.
-- Chat requests (`messages`) are reduced to their 'user' prompts: either all of
-- them concatenated, or only the last one when
-- `conf.allow_all_conversation_history` is set. Completion requests use
-- `prompt` directly. Deny patterns are evaluated before allow patterns.
-- @param request decoded JSON request body
-- @param conf    plugin configuration (allow_patterns, deny_patterns,
--                allow_all_conversation_history)
-- @return true when the prompt is acceptable; otherwise a falsy value and an
--         error message describing why it was rejected
function _M.execute(request, conf)
  -- a missing or non-object body cannot carry a prompt; fail instead of
  -- raising an "attempt to index" error
  if type(request) ~= "table" then
    return nil, "no 'prompt' or 'messages' received"
  end

  local user_prompt
  local messages = request.messages

  if messages then
    if type(messages) ~= "table" then
      return nil, "'messages' must be an array"
    end

    if not conf.allow_all_conversation_history then
      -- the whole history must be checked: concat every 'user' prompt
      local buf = buffer.new()

      for _, message in ipairs(messages) do
        if type(message) == "table" and message.role == "user" then
          -- fail closed: content that is not plain text cannot be inspected
          if type(message.content) ~= "string" then
            return nil, "'user' message content must be a string"
          end

          buf:put(message.content)
        end
      end

      user_prompt = buf:get()

    else
      -- just take the trailing 'user' prompt
      for _, message in ipairs(messages) do
        if type(message) == "table" and message.role == "user" then
          user_prompt = message.content
        end
      end
    end

  elseif request.prompt then
    user_prompt = request.prompt

  else
    return nil, "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"
  end

  if not user_prompt then
    return nil, "no 'prompt' or 'messages' received"
  end

  -- fail closed: only plain-text prompts can be matched against the patterns
  if type(user_prompt) ~= "string" then
    return nil, "prompt must be a string"
  end

  -- check the prompt for explicit ban patterns; any match denies immediately
  local deny_patterns = conf.deny_patterns
  if deny_patterns and #deny_patterns > 0 then
    local matched, bad_pattern = matches_any(user_prompt, deny_patterns)
    if matched == nil then
      return nil, "bad regex execution for: " .. bad_pattern
    end

    if matched then
      return nil, "prompt pattern is blocked"
    end
  end

  -- if any allow_patterns specified, make sure the prompt matches one of them
  local allow_patterns = conf.allow_patterns
  if allow_patterns and #allow_patterns > 0 then
    local matched, bad_pattern = matches_any(user_prompt, allow_patterns)
    if matched == nil then
      return nil, "bad regex execution for: " .. bad_pattern
    end

    if not matched then
      return false, "prompt doesn't match any allowed pattern"
    end
  end

  return true, nil
end
|
||
--- Access-phase handler: validate the incoming JSON prompt and reject the
-- request with HTTP 400 when it is not acceptable.
-- @param conf plugin configuration, passed through to `execute`
function _M:access(conf)
  kong.service.request.enable_buffering()
  kong.ctx.shared.ai_prompt_guarded = true -- future use

  -- if plugin ordering was altered, receive the "decorated" request
  local body, body_err = kong.request.get_body("application/json")
  if body_err then
    return bad_request("this LLM route only supports application/json requests", true)
  end

  -- run the guard; the precise reason is logged but not shown to the client
  local allowed, reason = self.execute(body, conf)
  if allowed then
    return
  end

  return bad_request(reason, false)
end
|
||
return _M |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
local typedefs = require "kong.db.schema.typedefs" | ||
|
||
-- Build the field definition shared by `allow_patterns` and `deny_patterns`:
-- an array of at most 10 strings, each between 1 and 50 characters long,
-- defaulting to an empty array.
local function pattern_list(description)
  return {
    description = description,
    type = "array",
    default = {},
    len_max = 10,
    elements = {
      type = "string",
      len_min = 1,
      len_max = 50,
    },
  }
end

local config_fields = {
  { allow_patterns = pattern_list(
      "Array of valid patterns, or valid questions from the 'user' role in chat.") },
  { deny_patterns = pattern_list(
      "Array of invalid patterns, or invalid questions from the 'user' role in chat.") },
  { allow_all_conversation_history = {
      description = "If true, will ignore all previous chat prompts from the conversation history.",
      type = "boolean",
      required = true,
      default = false,
  } },
}

return {
  name = "ai-prompt-guard",
  fields = {
    { protocols = typedefs.protocols_http },
    { config = {
        type = "record",
        fields = config_fields,
    } },
  },
  -- a guard with neither list configured would do nothing, so require one
  entity_checks = {
    { at_least_one_of = { "config.allow_patterns", "config.deny_patterns" } },
  },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
local PLUGIN_NAME = "ai-prompt-guard"


-- helper function to validate data against the plugin's schema;
-- the required modules stay private to the `do` block
local validate
do
  local check_config = require("spec.helpers").validate_plugin_config_schema
  local schema = require("kong.plugins." .. PLUGIN_NAME .. ".schema")

  validate = function(config)
    return check_config(config, schema)
  end
end
|
||
describe(PLUGIN_NAME .. ": (schema)", function()
  -- error reported by the `at_least_one_of` entity check
  local NO_PATTERNS_ERR = "at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'"

  -- validate `config`, assert that it was rejected, and hand back the errors
  local function assert_rejected(config)
    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)

    return err
  end

  it("won't allow both allow_patterns and deny_patterns to be unset", function()
    local err = assert_rejected({
      allow_all_conversation_history = true,
    })

    assert.equal(NO_PATTERNS_ERR, err["@entity"][1])
  end)

  it("won't allow both allow_patterns and deny_patterns to be empty arrays", function()
    local err = assert_rejected({
      allow_all_conversation_history = true,
      allow_patterns = {},
      deny_patterns = {},
    })

    assert.equal(NO_PATTERNS_ERR, err["@entity"][1])
  end)

  it("won't allow patterns that are too long", function()
    local err = assert_rejected({
      allow_all_conversation_history = true,
      allow_patterns = {
        "123456789012345678901234567890123456789012345678901", -- 51
      },
    })

    assert.same({ config = { allow_patterns = { [1] = "length must be at most 50" } } }, err)
  end)

  it("won't allow too many array items", function()
    -- one more entry than the schema's limit of 10
    local patterns = {}
    for i = 1, 11 do
      patterns[i] = "pattern"
    end

    local err = assert_rejected({
      allow_all_conversation_history = true,
      allow_patterns = patterns,
    })

    assert.same({ config = { allow_patterns = "length must be at most 10" } }, err)
  end)
end)
Oops, something went wrong.