Skip to content

Commit

Permalink
feat(plugins): ai-prompt-guard-plugin (#12337)
Browse files Browse the repository at this point in the history
* feat(plugins): ai-prompt-guard-plugin

* fix(ai-prompt-guard): fixes from code review

* Update kong/plugins/ai-prompt-guard/schema.lua

Co-authored-by: Vinicius Mignot <[email protected]>

---------

Co-authored-by: Jack Tysoe <[email protected]>
Co-authored-by: Vinicius Mignot <[email protected]>
  • Loading branch information
3 people authored Jan 25, 2024
1 parent 3ef9235 commit 93a1887
Show file tree
Hide file tree
Showing 10 changed files with 897 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ plugins/ai-response-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*']

plugins/ai-prompt-guard:
- changed-files:
- any-glob-to-any-file: kong/plugins/ai-prompt-guard/**/*

plugins/aws-lambda:
- changed-files:
- any-glob-to-any-file: kong/plugins/aws-lambda/**/*
Expand Down
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/add-ai-prompt-guard-plugin.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Prompt Guard** which can allow and/or block LLM requests based on pattern matching.
type: feature
scope: Plugin
3 changes: 3 additions & 0 deletions kong-3.6.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ build = {
["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua",
["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua",

["kong.plugins.ai-prompt-guard.handler"] = "kong/plugins/ai-prompt-guard/handler.lua",
["kong.plugins.ai-prompt-guard.schema"] = "kong/plugins/ai-prompt-guard/schema.lua",

["kong.vaults.env"] = "kong/vaults/env/init.lua",
["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua",

Expand Down
1 change: 1 addition & 0 deletions kong/constants.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ local plugins = {
"ai-proxy",
"ai-prompt-decorator",
"ai-prompt-template",
"ai-prompt-guard",
"ai-request-transformer",
"ai-response-transformer",
}
Expand Down
112 changes: 112 additions & 0 deletions kong/plugins/ai-prompt-guard/handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
-- plugin handler table; populated below and returned at the end of the file
local _M = {}

-- imports
local kong_meta = require "kong.meta"
-- string buffer library, used to concatenate 'user' prompts
local buffer = require("string.buffer")
-- cached reference to the regex matcher used for allow/deny patterns
local ngx_re_find = ngx.re.find
--

-- NOTE(review): presumably the plugin execution priority relative to the
-- other ai-* plugins -- confirm against the plugins order spec
_M.PRIORITY = 771
-- plugin version is taken from kong.meta
_M.VERSION = kong_meta.version

--- Rejects the current request with an HTTP 400 response.
-- The real reason is always written to the info log. The client only sees it
-- when `reveal_msg_to_client` is truthy; otherwise a generic message is sent
-- so that callers cannot tell 'ai-prompt-guard' is in use.
-- @param msg reason for the rejection
-- @param reveal_msg_to_client when truthy, `msg` is returned to the client
-- @return whatever kong.response.exit returns
local function bad_request(msg, reveal_msg_to_client)
  kong.log.info(msg)

  local client_msg = "bad request"
  if reveal_msg_to_client then
    client_msg = msg
  end

  return kong.response.exit(400, { error = { message = client_msg } })
end

--- Coerces a prompt value into text that can be matched against patterns.
-- Strings pass through and numbers are stringified; any other value (nil,
-- tables such as multi-part content, JSON null, ...) yields nil.
local function prompt_text(value)
  local value_type = type(value)

  if value_type == "string" then
    return value
  end

  if value_type == "number" then
    return tostring(value)
  end

  return nil
end


--- Extracts the text to inspect from a chat 'messages' array.
-- @param messages array of chat messages
-- @param last_only when truthy only the trailing 'user' message is used,
--   otherwise the content of every 'user' message is concatenated
-- @return the prompt text; nil when `last_only` is set and there is no
--   'user' message at all
-- @return an error string when a 'user' message carries non-text content
local function collect_user_prompt(messages, last_only)
  -- non-text content cannot be matched reliably against the patterns, so it
  -- is rejected rather than silently passed through the guard
  local bad_content = "ai-prompt-guard only supports text 'content' in 'user' messages"

  if last_only then
    -- just take the trailing 'user' prompt
    local last_user

    for _, message in ipairs(messages) do
      if type(message) == "table" and message.role == "user" then
        last_user = message
      end
    end

    if not last_user then
      return nil
    end

    local text = prompt_text(last_user.content)
    if not text then
      return nil, bad_content
    end

    return text
  end

  -- concat all 'user' prompts into one string.
  -- NOTE: prompts are joined without a separator, so a pattern may match
  -- across the boundary between two consecutive messages.
  local buf = buffer.new()

  for _, message in ipairs(messages) do
    if type(message) == "table" and message.role == "user" then
      local text = prompt_text(message.content)
      if not text then
        return nil, bad_content
      end

      buf:put(text)
    end
  end

  return buf:get()
end


--- Checks whether the prompt matches at least one of the given patterns.
-- @param prompt text to inspect
-- @param patterns array of regex strings
-- @return true on the first match, false when nothing matched, nil on error
-- @return an error string when a regex failed to execute
local function matches_any(prompt, patterns)
  for _, pattern in ipairs(patterns) do
    local from, _, err = ngx_re_find(prompt, pattern, "jo")

    if err then
      return nil, "bad regex execution for: " .. pattern
    end

    if from then
      return true
    end
  end

  return false
end


--- Decides whether a decoded LLM request passes the configured guard.
-- Supports chat requests (`request.messages`) and completion requests
-- (`request.prompt`). Deny patterns are evaluated first; if any allow
-- patterns are configured the prompt must match at least one of them.
-- @param request decoded JSON request body
-- @param conf plugin configuration (`allow_patterns`, `deny_patterns`,
--   `allow_all_conversation_history`)
-- @return true when the prompt is accepted, otherwise a falsy value
-- @return nil on success, otherwise a string describing the rejection
function _M.execute(request, conf)
  local unsupported = "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"

  if type(request) ~= "table" then
    return nil, unsupported
  end

  local user_prompt
  local messages = request.messages

  if messages then
    if type(messages) ~= "table" then
      return nil, unsupported
    end

    -- the whole conversation history is checked unless it is explicitly allowed
    local err
    user_prompt, err = collect_user_prompt(messages, conf.allow_all_conversation_history)
    if err then
      return nil, err
    end

  elseif request.prompt then
    user_prompt = prompt_text(request.prompt)
    if not user_prompt then
      return nil, "ai-prompt-guard only supports a text 'prompt'"
    end

  else
    return nil, unsupported
  end

  if not user_prompt then
    return nil, "no 'prompt' or 'messages' received"
  end

  -- check the prompt for explicit ban patterns; any match denies immediately
  local deny_patterns = conf.deny_patterns
  if deny_patterns and #deny_patterns > 0 then
    local matched, err = matches_any(user_prompt, deny_patterns)

    if err then
      return nil, err
    end

    if matched then
      return nil, "prompt pattern is blocked"
    end
  end

  -- if any allow_patterns specified, make sure the prompt matches one of them
  local allow_patterns = conf.allow_patterns
  if allow_patterns and #allow_patterns > 0 then
    local matched, err = matches_any(user_prompt, allow_patterns)

    if err then
      return nil, err
    end

    if not matched then
      return false, "prompt doesn't match any allowed pattern"
    end
  end

  return true, nil
end

--- Access-phase handler: reads the JSON request body and runs the guard.
-- Requests whose body cannot be read as application/json, or whose prompt is
-- rejected by `execute`, are answered with HTTP 400.
-- @param conf plugin configuration
function _M:access(conf)
  kong.service.request.enable_buffering()
  kong.ctx.shared.ai_prompt_guarded = true -- future use

  -- if plugin ordering was altered, receive the "decorated" request
  local body, body_err = kong.request.get_body("application/json")
  if body_err then
    return bad_request("this LLM route only supports application/json requests", true)
  end

  -- run the guard; the rejection reason is logged but not shown to the client
  local allowed, reason = self.execute(body, conf)
  if allowed then
    return
  end

  return bad_request(reason, false)
end

return _M
44 changes: 44 additions & 0 deletions kong/plugins/ai-prompt-guard/schema.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
local typedefs = require "kong.db.schema.typedefs"

-- Builds the definition shared by the allow/deny pattern fields: an array of
-- at most 10 strings, each between 1 and 50 characters long, defaulting to an
-- empty array. A fresh table is returned on every call.
local function pattern_list(description)
  return {
    description = description,
    type = "array",
    default = {},
    len_max = 10,
    elements = {
      type = "string",
      len_min = 1,
      len_max = 50,
    },
  }
end

-- fields of the plugin's `config` record, in declaration order
local config_fields = {
  { allow_patterns = pattern_list(
      "Array of valid patterns, or valid questions from the 'user' role in chat.") },
  { deny_patterns = pattern_list(
      "Array of invalid patterns, or invalid questions from the 'user' role in chat.") },
  { allow_all_conversation_history = {
      description = "If true, will ignore all previous chat prompts from the conversation history.",
      type = "boolean",
      required = true,
      default = false,
  } },
}

return {
  name = "ai-prompt-guard",
  fields = {
    { protocols = typedefs.protocols_http },
    { config = {
        type = "record",
        fields = config_fields,
    } },
  },
  entity_checks = {
    -- a guard with neither allow nor deny patterns is rejected at validation
    { at_least_one_of = { "config.allow_patterns", "config.deny_patterns" } },
  },
}
1 change: 1 addition & 0 deletions spec/01-unit/12-plugins_order_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ describe("Plugins", function()
"ai-request-transformer",
"ai-prompt-template",
"ai-prompt-decorator",
"ai-prompt-guard",
"ai-proxy",
"ai-response-transformer",
"aws-lambda",
Expand Down
80 changes: 80 additions & 0 deletions spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
local PLUGIN_NAME = "ai-prompt-guard"


-- Validates a config table against this plugin's schema; the schema and the
-- validation helper are resolved once, when the spec file is loaded.
local validate
do
  local validate_plugin_config_schema = require("spec.helpers").validate_plugin_config_schema
  local schema = require("kong.plugins." .. PLUGIN_NAME .. ".schema")

  validate = function(config)
    return validate_plugin_config_schema(config, schema)
  end
end

describe(PLUGIN_NAME .. ": (schema)", function()
  -- entity-level error expected when neither pattern list has any entries
  local AT_LEAST_ONE_ERR = "at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'"

  it("won't allow both allow_patterns and deny_patterns to be unset", function()
    local ok, err = validate({
      allow_all_conversation_history = true,
    })

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal(AT_LEAST_ONE_ERR, err["@entity"][1])
  end)

  it("won't allow both allow_patterns and deny_patterns to be empty arrays", function()
    local ok, err = validate({
      allow_all_conversation_history = true,
      allow_patterns = {},
      deny_patterns = {},
    })

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal(AT_LEAST_ONE_ERR, err["@entity"][1])
  end)

  it("won't allow patterns that are too long", function()
    -- 51 characters, one more than the schema allows
    local too_long = "123456789012345678901234567890123456789012345678901"

    local ok, err = validate({
      allow_all_conversation_history = true,
      allow_patterns = { too_long },
    })

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.same({ config = {allow_patterns = { [1] = "length must be at most 50" }}}, err)
  end)

  it("won't allow too many array items", function()
    -- 11 entries, one more than the schema allows
    local patterns = {}
    for i = 1, 11 do
      patterns[i] = "pattern"
    end

    local ok, err = validate({
      allow_all_conversation_history = true,
      allow_patterns = patterns,
    })

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.same({ config = {allow_patterns = "length must be at most 10" }}, err)
  end)
end)
Loading

0 comments on commit 93a1887

Please sign in to comment.