Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport -> release/3.6.x] feat(plugins): ai-prompt-guard-plugin #12337 #12427

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ plugins/ai-response-transformer:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*']

plugins/ai-prompt-guard:
- changed-files:
- any-glob-to-any-file: kong/plugins/ai-prompt-guard/**/*

plugins/aws-lambda:
- changed-files:
- any-glob-to-any-file: kong/plugins/aws-lambda/**/*
Expand Down
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/add-ai-prompt-guard-plugin.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: Introduced the new **AI Prompt Guard** plugin, which can allow and/or block LLM requests based on pattern matching.
type: feature
scope: Plugin
3 changes: 3 additions & 0 deletions kong-3.6.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ build = {
["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua",
["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua",

["kong.plugins.ai-prompt-guard.handler"] = "kong/plugins/ai-prompt-guard/handler.lua",
["kong.plugins.ai-prompt-guard.schema"] = "kong/plugins/ai-prompt-guard/schema.lua",

["kong.vaults.env"] = "kong/vaults/env/init.lua",
["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua",

Expand Down
1 change: 1 addition & 0 deletions kong/constants.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ local plugins = {
"ai-proxy",
"ai-prompt-decorator",
"ai-prompt-template",
"ai-prompt-guard",
"ai-request-transformer",
"ai-response-transformer",
}
Expand Down
112 changes: 112 additions & 0 deletions kong/plugins/ai-prompt-guard/handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
local _M = {}

-- imports
local kong_meta = require "kong.meta"
local buffer = require("string.buffer")
local ngx_re_find = ngx.re.find
--

-- Plugin metadata exposed on the handler table.
-- NOTE(review): PRIORITY presumably controls where this plugin runs relative
-- to the other AI plugins -- confirm against the plugin order spec.
_M.PRIORITY = 771
-- the plugin version is taken from kong.meta
_M.VERSION = kong_meta.version

--- Terminate the request with an HTTP 400 response.
-- The real reason is always written to the Kong log at info level. The client
-- only receives it when `reveal_msg_to_client` is truthy; otherwise a generic
-- message is sent so the presence of 'ai-prompt-guard' is not disclosed.
-- @param msg                  reason for the rejection
-- @param reveal_msg_to_client when truthy, `msg` is returned to the client
-- @return whatever kong.response.exit returns
local function bad_request(msg, reveal_msg_to_client)
  kong.log.info(msg)

  -- generic by default; only swapped for the real reason on request
  local client_message = "bad request"
  if reveal_msg_to_client then
    client_message = msg
  end

  local payload = { error = { message = client_message } }
  return kong.response.exit(400, payload)
end

-- Tests `prompt` against each regex in `patterns`, in order.
-- Returns true on the first match, false when nothing matches, or
-- nil plus the offending pattern when the regex engine reports an error.
local function matches_any(prompt, patterns)
  for _, pattern in ipairs(patterns) do
    local from, _, err = ngx_re_find(prompt, pattern, "jo")
    if err then
      return nil, pattern
    end

    if from then
      return true
    end
  end

  return false
end

--- Decide whether a decoded LLM request may pass the prompt guard.
-- Supports chat requests (`request.messages`) and completion requests
-- (`request.prompt`). For chat requests, every 'user' message is checked
-- unless `conf.allow_all_conversation_history` is set, in which case only
-- the trailing 'user' message is checked.
-- @param request decoded JSON request body
-- @param conf    plugin configuration (allow_patterns, deny_patterns,
--                allow_all_conversation_history)
-- @return true when the prompt is acceptable; otherwise a falsy value plus
--         a reason string (intended for logging, not for the client)
function _M.execute(request, conf)
  -- a body that is not a JSON object cannot carry 'messages' or 'prompt'
  if type(request) ~= "table" then
    return nil, "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"
  end

  local user_prompt
  local messages = request.messages

  if messages then
    -- guard against e.g. a string or number where an array is expected
    if type(messages) ~= "table" then
      return nil, "'messages' must be an array of chat messages"
    end

    if conf.allow_all_conversation_history then
      -- history is trusted: just take the trailing 'user' prompt
      for _, message in ipairs(messages) do
        if type(message) == "table" and message.role == "user" then
          user_prompt = message.content
        end
      end

    else
      -- concat all 'user' prompts into one string so the whole
      -- conversation history is checked
      local buf = buffer.new()

      for _, message in ipairs(messages) do
        if type(message) == "table" and message.role == "user" then
          -- reject instead of letting buf:put raise on nil/null/table content
          if type(message.content) ~= "string" then
            return nil, "user message content must be a string"
          end

          buf:put(message.content)
        end
      end

      -- NOTE(review): with no 'user' messages this is expected to be the
      -- empty string, which is still matched against the patterns below
      user_prompt = buf:get()
    end

  elseif request.prompt then
    user_prompt = request.prompt

  else
    return nil, "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"
  end

  if not user_prompt then
    return nil, "no 'prompt' or 'messages' received"
  end

  -- never match patterns against the string form of a non-string value
  if type(user_prompt) ~= "string" then
    return nil, "prompt must be a string"
  end

  -- check the prompt for explicit ban patterns; any match denies immediately
  if conf.deny_patterns and #conf.deny_patterns > 0 then
    local matched, bad_pattern = matches_any(user_prompt, conf.deny_patterns)
    if matched == nil then
      return nil, "bad regex execution for: " .. bad_pattern
    end

    if matched then
      return nil, "prompt pattern is blocked"
    end
  end

  -- if any allow_patterns are specified, the prompt must match one of them
  if conf.allow_patterns and #conf.allow_patterns > 0 then
    local matched, bad_pattern = matches_any(user_prompt, conf.allow_patterns)
    if matched == nil then
      return nil, "bad regex execution for: " .. bad_pattern
    end

    if not matched then
      return false, "prompt doesn't match any allowed pattern"
    end
  end

  return true, nil
end

--- Access-phase handler: parse the JSON request body and run the guard.
-- Rejects the request with a 400 when the body cannot be read as
-- application/json or when `execute` refuses the prompt.
-- @param conf plugin configuration, passed through to `execute`
function _M:access(conf)
  kong.service.request.enable_buffering()
  kong.ctx.shared.ai_prompt_guarded = true -- future use

  -- if plugin ordering was altered, this is the already "decorated" request
  local body, body_err = kong.request.get_body("application/json")
  if body_err then
    return bad_request("this LLM route only supports application/json requests", true)
  end

  -- evaluate the prompt; the rejection reason is logged, not sent to the client
  local allowed, reason = self.execute(body, conf)
  if allowed then
    return
  end

  return bad_request(reason, false)
end

return _M
44 changes: 44 additions & 0 deletions kong/plugins/ai-prompt-guard/schema.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
local typedefs = require "kong.db.schema.typedefs"

-- Builds one pattern-list field definition: an array (default empty) of at
-- most 10 strings, each between 1 and 50 characters long.
local function pattern_list(description)
  return {
    description = description,
    type = "array",
    default = {},
    len_max = 10,
    elements = {
      type = "string",
      len_min = 1,
      len_max = 50,
    },
  }
end

-- Schema of the ai-prompt-guard plugin configuration.
return {
  name = "ai-prompt-guard",
  fields = {
    { protocols = typedefs.protocols_http },
    { config = {
        type = "record",
        fields = {
          { allow_patterns = pattern_list("Array of valid patterns, or valid questions from the 'user' role in chat.") },
          { deny_patterns = pattern_list("Array of invalid patterns, or invalid questions from the 'user' role in chat.") },
          { allow_all_conversation_history = {
              description = "If true, will ignore all previous chat prompts from the conversation history.",
              type = "boolean",
              required = true,
              default = false,
          } },
        },
    } },
  },
  -- a guard with neither list configured would have nothing to check
  entity_checks = {
    { at_least_one_of = { "config.allow_patterns", "config.deny_patterns" } },
  },
}
1 change: 1 addition & 0 deletions spec/01-unit/12-plugins_order_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ describe("Plugins", function()
"ai-request-transformer",
"ai-prompt-template",
"ai-prompt-decorator",
"ai-prompt-guard",
"ai-proxy",
"ai-response-transformer",
"aws-lambda",
Expand Down
80 changes: 80 additions & 0 deletions spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
local PLUGIN_NAME = "ai-prompt-guard"


-- helper function to validate data against a schema
-- helper: validates a plugin config table against this plugin's schema
local validate do
  local check_plugin_config = require("spec.helpers").validate_plugin_config_schema
  local schema = require("kong.plugins." .. PLUGIN_NAME .. ".schema")

  validate = function(config)
    return check_plugin_config(config, schema)
  end
end

describe(PLUGIN_NAME .. ": (schema)", function()
  -- error expected when neither pattern list is populated
  local MISSING_PATTERNS_ERR = "at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'"

  -- validates `config`, asserts that it was rejected, and returns the errors
  local function expect_invalid(config)
    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)

    return err
  end

  it("won't allow both allow_patterns and deny_patterns to be unset", function()
    local err = expect_invalid({
      allow_all_conversation_history = true,
    })

    assert.equal(MISSING_PATTERNS_ERR, err["@entity"][1])
  end)

  it("won't allow both allow_patterns and deny_patterns to be empty arrays", function()
    local err = expect_invalid({
      allow_all_conversation_history = true,
      allow_patterns = {},
      deny_patterns = {},
    })

    assert.equal(MISSING_PATTERNS_ERR, err["@entity"][1])
  end)

  it("won't allow patterns that are too long", function()
    local err = expect_invalid({
      allow_all_conversation_history = true,
      allow_patterns = {
        "123456789012345678901234567890123456789012345678901", -- 51
      },
    })

    assert.same({ config = {allow_patterns = { [1] = "length must be at most 50" }}}, err)
  end)

  it("won't allow too many array items", function()
    -- one more entry than the schema permits
    local patterns = {}
    for i = 1, 11 do
      patterns[i] = "pattern"
    end

    local err = expect_invalid({
      allow_all_conversation_history = true,
      allow_patterns = patterns,
    })

    assert.same({ config = {allow_patterns = "length must be at most 10" }}, err)
  end)
end)
Loading
Loading