diff --git a/.github/labeler.yml b/.github/labeler.yml index d40e0799a35..2f6fe24f700 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -110,6 +110,10 @@ plugins/ai-response-transformer: - changed-files: - any-glob-to-any-file: ['kong/plugins/ai-response-transformer/**/*', 'kong/llm/**/*'] +plugins/ai-prompt-guard: +- changed-files: + - any-glob-to-any-file: kong/plugins/ai-prompt-guard/**/* + plugins/aws-lambda: - changed-files: - any-glob-to-any-file: kong/plugins/aws-lambda/**/* diff --git a/changelog/unreleased/kong/add-ai-prompt-guard-plugin.yml b/changelog/unreleased/kong/add-ai-prompt-guard-plugin.yml new file mode 100644 index 00000000000..dd0b8dbfb2e --- /dev/null +++ b/changelog/unreleased/kong/add-ai-prompt-guard-plugin.yml @@ -0,0 +1,3 @@ +message: Introduced the new **AI Prompt Guard** which can allow and/or block LLM requests based on pattern matching. +type: feature +scope: Plugin diff --git a/kong-3.6.0-0.rockspec b/kong-3.6.0-0.rockspec index c06a24019e3..c391df8f93b 100644 --- a/kong-3.6.0-0.rockspec +++ b/kong-3.6.0-0.rockspec @@ -587,6 +587,9 @@ build = { ["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua", ["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua", + ["kong.plugins.ai-prompt-guard.handler"] = "kong/plugins/ai-prompt-guard/handler.lua", + ["kong.plugins.ai-prompt-guard.schema"] = "kong/plugins/ai-prompt-guard/schema.lua", + ["kong.vaults.env"] = "kong/vaults/env/init.lua", ["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua", diff --git a/kong/constants.lua b/kong/constants.lua index 25163735016..e94e555383e 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -39,6 +39,7 @@ local plugins = { "ai-proxy", "ai-prompt-decorator", "ai-prompt-template", + "ai-prompt-guard", "ai-request-transformer", "ai-response-transformer", } diff --git a/kong/plugins/ai-prompt-guard/handler.lua b/kong/plugins/ai-prompt-guard/handler.lua new file mode 100644 index 00000000000..50c64315f71 --- /dev/null +++ b/kong/plugins/ai-prompt-guard/handler.lua @@ -0,0 +1,112 @@ +local _M = {} + +-- imports +local kong_meta = require "kong.meta" +local buffer = require("string.buffer") +local ngx_re_find = ngx.re.find +-- + +_M.PRIORITY = 771 +_M.VERSION = kong_meta.version + +local function bad_request(msg, reveal_msg_to_client) + -- don't let users know 'ai-prompt-guard' is in use + kong.log.info(msg) + if not reveal_msg_to_client then + msg = "bad request" + end + return kong.response.exit(400, { error = { message = msg } }) +end + +function _M.execute(request, conf) + local user_prompt + + -- concat all 'user' prompts into one string, if conversation history must be checked + if request.messages and not conf.allow_all_conversation_history then + local buf = buffer.new() + + for _, v in ipairs(request.messages) do + if v.role == "user" then + buf:put(v.content) + end + end + + user_prompt = buf:get() + + elseif request.messages then + -- just take the trailing 'user' prompt + for _, v in ipairs(request.messages) do + if v.role == "user" then + user_prompt = v.content + end + end + + elseif request.prompt then + user_prompt = request.prompt + + else + return nil, "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts" + end + + if not user_prompt then + return nil, "no 'prompt' or 'messages' received" + end + + -- check the prompt for explcit ban patterns + if conf.deny_patterns and #conf.deny_patterns > 0 then + for _, v in ipairs(conf.deny_patterns) do + -- check each denylist; if prompt matches it, deny immediately + local m, _, err = ngx_re_find(user_prompt, v, "jo") + if err then + return nil, "bad regex execution for: " .. v + + elseif m then + return nil, "prompt pattern is blocked" + end + end + end + + -- if any allow_patterns specified, make sure the prompt matches one of them + if conf.allow_patterns and #conf.allow_patterns > 0 then + local valid = false + + for _, v in ipairs(conf.allow_patterns) do + -- check each denylist; if prompt matches it, deny immediately + local m, _, err = ngx_re_find(user_prompt, v, "jo") + + if err then + return nil, "bad regex execution for: " .. v + + elseif m then + valid = true + break + end + end + + if not valid then + return false, "prompt doesn't match any allowed pattern" + end + end + + return true, nil +end + +function _M:access(conf) + kong.service.request.enable_buffering() + kong.ctx.shared.ai_prompt_guarded = true -- future use + + -- if plugin ordering was altered, receive the "decorated" request + local request, err = kong.request.get_body("application/json") + + if err then + return bad_request("this LLM route only supports application/json requests", true) + end + + -- run access handler + local ok, err = self.execute(request, conf) + if not ok then + return bad_request(err, false) + end +end + +return _M diff --git a/kong/plugins/ai-prompt-guard/schema.lua b/kong/plugins/ai-prompt-guard/schema.lua new file mode 100644 index 00000000000..da4dd49eebc --- /dev/null +++ b/kong/plugins/ai-prompt-guard/schema.lua @@ -0,0 +1,44 @@ +local typedefs = require "kong.db.schema.typedefs" + +return { + name = "ai-prompt-guard", + fields = { + { protocols = typedefs.protocols_http }, + { config = { + type = "record", + fields = { + { allow_patterns = { + description = "Array of valid patterns, or valid questions from the 'user' role in chat.", + type = "array", + default = {}, + len_max = 10, + elements = { + type = "string", + len_min = 1, + len_max = 50, + }}}, + { deny_patterns = { + description = "Array of invalid patterns, or invalid questions from the 'user' role in chat.", + type = "array", + default = {}, + len_max = 10, + elements = { + type = "string", + len_min = 1, + len_max = 50, + }}}, + { allow_all_conversation_history = { + description = "If true, will ignore all previous chat prompts from the conversation history.", + type = "boolean", + required = true, + default = false } }, + } + } + } + }, + entity_checks = { + { + at_least_one_of = { "config.allow_patterns", "config.deny_patterns" }, + } + } +} diff --git a/spec/01-unit/12-plugins_order_spec.lua b/spec/01-unit/12-plugins_order_spec.lua index 8189d05e992..d897784255e 100644 --- a/spec/01-unit/12-plugins_order_spec.lua +++ b/spec/01-unit/12-plugins_order_spec.lua @@ -75,6 +75,7 @@ describe("Plugins", function() "ai-request-transformer", "ai-prompt-template", "ai-prompt-decorator", + "ai-prompt-guard", "ai-proxy", "ai-response-transformer", "aws-lambda", diff --git a/spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua b/spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua new file mode 100644 index 00000000000..ff8cc21669f --- /dev/null +++ b/spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua @@ -0,0 +1,80 @@ +local PLUGIN_NAME = "ai-prompt-guard" + + +-- helper function to validate data against a schema +local validate do + local validate_entity = require("spec.helpers").validate_plugin_config_schema + local plugin_schema = require("kong.plugins." .. PLUGIN_NAME .. ".schema") + + function validate(data) + return validate_entity(data, plugin_schema) + end +end + +describe(PLUGIN_NAME .. ": (schema)", function() + it("won't allow both allow_patterns and deny_patterns to be unset", function() + local config = { + allow_all_conversation_history = true, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.not_nil(err) + assert.equal("at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'", err["@entity"][1]) + end) + + it("won't allow both allow_patterns and deny_patterns to be empty arrays", function() + local config = { + allow_all_conversation_history = true, + allow_patterns = {}, + deny_patterns = {}, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.not_nil(err) + assert.equal("at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'", err["@entity"][1]) + end) + + it("won't allow patterns that are too long", function() + local config = { + allow_all_conversation_history = true, + allow_patterns = { + [1] = "123456789012345678901234567890123456789012345678901" -- 51 + }, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.not_nil(err) + assert.same({ config = {allow_patterns = { [1] = "length must be at most 50" }}}, err) + end) + + it("won't allow too many array items", function() + local config = { + allow_all_conversation_history = true, + allow_patterns = { + [1] = "pattern", + [2] = "pattern", + [3] = "pattern", + [4] = "pattern", + [5] = "pattern", + [6] = "pattern", + [7] = "pattern", + [8] = "pattern", + [9] = "pattern", + [10] = "pattern", + [11] = "pattern", + }, + } + + local ok, err = validate(config) + + assert.is_falsy(ok) + assert.not_nil(err) + assert.same({ config = {allow_patterns = "length must be at most 10" }}, err) + end) +end) diff --git a/spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua new file mode 100644 index 00000000000..ac82622755c --- /dev/null +++ b/spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua @@ -0,0 +1,221 @@ +local PLUGIN_NAME = "ai-prompt-guard" +local access_handler = require("kong.plugins.ai-prompt-guard.handler") + + +local general_chat_request = { + messages = { + [1] = { + role = "system", + content = "You are a mathematician." + }, + [2] = { + role = "user", + content = "What is 1 + 1?" + }, + }, +} + +local general_chat_request_with_history = { + messages = { + [1] = { + role = "system", + content = "You are a mathematician." + }, + [2] = { + role = "user", + content = "What is 12 + 1?" + }, + [3] = { + role = "assistant", + content = "The answer is 13.", + }, + [4] = { + role = "user", + content = "Now double the previous answer.", + }, + }, +} + +local denied_chat_request = { + messages = { + [1] = { + role = "system", + content = "You are a mathematician." + }, + [2] = { + role = "user", + content = "What is 22 + 1?" + }, + }, +} + +local neither_allowed_nor_denied_chat_request = { + messages = { + [1] = { + role = "system", + content = "You are a mathematician." + }, + [2] = { + role = "user", + content = "What is 55 + 55?" + }, + }, +} + + +local general_completions_request = { + prompt = "You are a mathematician. What is 1 + 1?" +} + + +local denied_completions_request = { + prompt = "You are a mathematician. What is 22 + 1?" +} + +local neither_allowed_nor_denied_completions_request = { + prompt = "You are a mathematician. What is 55 + 55?" +} + +local allow_patterns_no_history = { + allow_patterns = { + [1] = ".*1 \\+ 1.*" + }, + allow_all_conversation_history = true, +} + +local allow_patterns_with_history = { + allow_patterns = { + [1] = ".*1 \\+ 1.*" + }, + allow_all_conversation_history = false, +} + +local deny_patterns_with_history = { + deny_patterns = { + [1] = ".*12 \\+ 1.*" + }, + allow_all_conversation_history = false, +} + +local deny_patterns_no_history = { + deny_patterns = { + [1] = ".*22 \\+ 1.*" + }, + allow_all_conversation_history = true, +} + +local both_patterns_no_history = { + allow_patterns = { + [1] = ".*1 \\+ 1.*" + }, + deny_patterns = { + [1] = ".*99 \\+ 99.*" + }, + allow_all_conversation_history = true, +} + +describe(PLUGIN_NAME .. ": (unit)", function() + + + describe("chat operations", function() + + it("allows request when only conf.allow_patterns is set", function() + local ok, err = access_handler.execute(general_chat_request, allow_patterns_no_history) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("allows request when only conf.deny_patterns is set, and pattern should not match", function() + local ok, err = access_handler.execute(general_chat_request, deny_patterns_no_history) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("denies request when only conf.allow_patterns is set, and pattern should not match", function() + local ok, err = access_handler.execute(denied_chat_request, allow_patterns_no_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt doesn't match any allowed pattern") + end) + + it("denies request when only conf.deny_patterns is set, and pattern should match", function() + local ok, err = access_handler.execute(denied_chat_request, deny_patterns_no_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt pattern is blocked") + end) + + it("allows request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches allow", function() + local ok, err = access_handler.execute(general_chat_request, both_patterns_no_history) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() + local ok, err = access_handler.execute(neither_allowed_nor_denied_chat_request, both_patterns_no_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt doesn't match any allowed pattern") + end) + + it("denies request when only conf.allow_patterns is set and previous chat history should not match", function() + local ok, err = access_handler.execute(general_chat_request_with_history, allow_patterns_with_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt doesn't match any allowed pattern") + end) + + it("denies request when only conf.deny_patterns is set and previous chat history should match", function() + local ok, err = access_handler.execute(general_chat_request_with_history, deny_patterns_with_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt pattern is blocked") + end) + + end) + + + describe("completions operations", function() + + it("allows request when only conf.allow_patterns is set", function() + local ok, err = access_handler.execute(general_completions_request, allow_patterns_no_history) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("allows request when only conf.deny_patterns is set, and pattern should not match", function() + local ok, err = access_handler.execute(general_completions_request, deny_patterns_no_history) + + assert.is_truthy(ok) + assert.is_nil(err) + end) + + it("denies request when only conf.allow_patterns is set, and pattern should not match", function() + local ok, err = access_handler.execute(denied_completions_request, allow_patterns_no_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt doesn't match any allowed pattern") + end) + + it("denies request when only conf.deny_patterns is set, and pattern should match", function() + local ok, err = access_handler.execute(denied_completions_request, deny_patterns_no_history) + + assert.is_falsy(ok) + assert.equal("prompt pattern is blocked", err) + end) + + it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() + local ok, err = access_handler.execute(neither_allowed_nor_denied_completions_request, both_patterns_no_history) + + assert.is_falsy(ok) + assert.equal(err, "prompt doesn't match any allowed pattern") + end) + + end) + + +end) diff --git a/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua b/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua new file mode 100644 index 00000000000..d5ffdf8b535 --- /dev/null +++ b/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua @@ -0,0 +1,428 @@ +local helpers = require "spec.helpers" + +local PLUGIN_NAME = "ai-prompt-guard" + + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() + local client + + lazy_setup(function() + + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + + -- both + local permit_history = bp.routes:insert({ + paths = { "~/permit-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = permit_history.id }, + config = { + allow_patterns = { + [1] = ".*cheddar.*", + [2] = ".*brie.*", + }, + deny_patterns = { + [1] = ".*leicester.*", + [2] = ".*edam.*", + }, + allow_all_conversation_history = true, + }, + } + + local block_history = bp.routes:insert({ + paths = { "~/block-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = block_history.id }, + config = { + allow_patterns = { + [1] = ".*cheddar.*", + [2] = ".*brie.*", + }, + deny_patterns = { + [1] = ".*leicester.*", + [2] = ".*edam.*", + }, + allow_all_conversation_history = false, + }, + } + -- + + -- allows only + local permit_history_allow_only = bp.routes:insert({ + paths = { "~/allow-only-permit-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = permit_history_allow_only.id }, + config = { + allow_patterns = { + [1] = ".*cheddar.*", + [2] = ".*brie.*", + }, + allow_all_conversation_history = true, + }, + } + + local block_history_allow_only = bp.routes:insert({ + paths = { "~/allow-only-block-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = block_history_allow_only.id }, + config = { + allow_patterns = { + [1] = ".*cheddar.*", + [2] = ".*brie.*", + }, + allow_all_conversation_history = false, + }, + } + -- + + -- denies only + local permit_history_deny_only = bp.routes:insert({ + paths = { "~/deny-only-permit-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = permit_history_deny_only.id }, + config = { + deny_patterns = { + [1] = ".*leicester.*", + [2] = ".*edam.*", + }, + allow_all_conversation_history = true, + }, + } + + local block_history_deny_only = bp.routes:insert({ + paths = { "~/deny-only-block-history$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = block_history_deny_only.id }, + config = { + deny_patterns = { + [1] = ".*leicester.*", + [2] = ".*edam.*", + }, + allow_all_conversation_history = false, + }, + } + -- + + assert(helpers.start_kong({ + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + plugins = "bundled," .. PLUGIN_NAME, + declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, + })) + end) + + lazy_teardown(function() + helpers.stop_kong(nil, true) + end) + + before_each(function() + client = helpers.proxy_client() + end) + + after_each(function() + if client then client:close() end + end) + + describe("request", function() + + -- both + it("allows message with 'allow' and 'deny' set, with history", function() + local r = client:get("/permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar is the best cheese." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why brie?" + } + ] + } + ]], + method = "POST", + }) + + -- the body is just an echo, don't need to test it + assert.res_status(200, r) + end) + + it("allows message with 'allow' and 'deny' set, without history", function() + local r = client:get("/block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar is the best cheese." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why brie?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(200, r) + end) + + it("blocks message with 'allow' and 'deny' set, with history", function() + local r = client:get("/permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar or edam are the best cheeses." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(400, r) + end) + -- + + -- allows only + it("allows message with 'allow' only set, with history", function() + local r = client:get("/allow-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that brie is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(200, r) + end) + + it("allows message with 'allow' only set, without history", function() + local r = client:get("/allow-only-block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that brie is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(200, r) + end) + + -- denies only + it("allows message with 'deny' only set, permit history", function() + local r = client:get("/deny-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will be permitted, because the BAD PHRASE is only in chat history, + -- which the developer "controls" + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(200, r) + end) + + it("blocks message with 'deny' only set, permit history", function() + local r = client:get("/deny-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will be blocks, because the BAD PHRASE is in the latest chat message, + -- which the user "controls" + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, edam is the best cheese." + }, + { + "role": "user", + "content": "Why edam?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(400, r) + end) + + it("blocks message with 'deny' only set, scan history", function() + local r = client:get("/deny-only-block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will NOT be permitted, because the BAD PHRASE is in chat history, + -- as specified by the Kong admins + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.res_status(400, r) + end) + -- + + end) + end) + +end end