Skip to content

Commit

Permalink
refactor(ai-prompt-guard): migrate ai-prompt-guard to new framework
Browse files Browse the repository at this point in the history
  • Loading branch information
fffonion committed Nov 20, 2024
1 parent bf0d5b6 commit 97234e0
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 144 deletions.
1 change: 1 addition & 0 deletions kong-3.9.0-0.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ build = {
["kong.plugins.ai-prompt-template.schema"] = "kong/plugins/ai-prompt-template/schema.lua",
["kong.plugins.ai-prompt-template.templater"] = "kong/plugins/ai-prompt-template/templater.lua",

["kong.plugins.ai-prompt-guard.filters.guard-prompt"] = "kong/plugins/ai-prompt-guard/filters/guard-prompt.lua",
["kong.plugins.ai-prompt-guard.handler"] = "kong/plugins/ai-prompt-guard/handler.lua",
["kong.plugins.ai-prompt-guard.schema"] = "kong/plugins/ai-prompt-guard/schema.lua",

Expand Down
151 changes: 151 additions & 0 deletions kong/plugins/ai-prompt-guard/filters/guard-prompt.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
-- This software is copyright Kong Inc. and its licensors.
-- Use of the software is subject to the agreement between your organization
-- and Kong Inc. If there is no such agreement, use is governed by and
-- subject to the terms of the Kong Master Software License Agreement found
-- at https://konghq.com/enterprisesoftwarelicense/.
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

local buffer = require("string.buffer")
local ngx_re_find = ngx.re.find
local ai_plugin_ctx = require("kong.llm.plugin.ctx")

-- Filter module table. NAME identifies this filter and is also passed below as
-- the namespace for its context accessors; STAGE is set to "REQ_TRANSFORMATION".
local _M = {
NAME = "guard-prompt",
STAGE = "REQ_TRANSFORMATION",
}

-- Schema handed to the context accessor factory: one key, `guarded`, declared
-- as "boolean". `_M:run` writes this key once the prompt has passed the checks.
local FILTER_OUTPUT_SCHEMA = {
guarded = "boolean",
}

-- Only the second returned accessor (the setter) is kept; the first is discarded.
local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)

-- Shared empty table used as a fallback when a pattern list is not configured.
local EMPTY = {}


-- Logs `msg` at debug level, then ends the request with a 400 response whose
-- JSON body is `{ error = { message = msg } }`.
-- @tparam string msg The message to log and to send back to the client
-- @return whatever `kong.response.exit` returns
local function bad_request(msg)
  kong.log.debug(msg)

  local body = {
    error = {
      message = msg,
    },
  }

  return kong.response.exit(400, body)
end



local execute do
  local bad_format_error = "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"

  -- Checks the prompt against the configured deny and allow patterns.
  -- _Note_: if a regex fails to execute, the error is logged and the request
  -- is ended through `kong.response.exit(500)`.
  -- @tparam table request The deserialized JSON body of the request
  -- @tparam table conf The plugin configuration
  -- @treturn[1] boolean `true` when the prompt is accepted
  -- @treturn[2] nil|boolean `nil` (bad format / denied) or `false` (not allowed)
  -- @treturn[2] string The error message
  function execute(request, conf)
    local collected_prompts
    local messages = request.messages

    -- concat all prompts into one string, if conversation history must be checked
    if type(messages) == "table" then
      local buf = buffer.new()
      -- Note: allow_all_conversation_history means the history is ignored,
      -- so only the latest matching message is collected
      local just_pick_latest = conf.allow_all_conversation_history

      -- iterate in reverse so we get the latest user prompt first
      -- instead of the oldest one in history
      for i = #messages, 1, -1 do
        local v = messages[i]

        -- entries come straight from the client's JSON body: anything that is
        -- not a table (a number, a JSON null, ...) must be rejected as a bad
        -- format before it is indexed, otherwise `v.role` can raise a Lua error
        if type(v) ~= "table" or type(v.role) ~= "string" then
          return nil, bad_format_error
        end

        if v.role == "user" or conf.match_all_roles then
          if type(v.content) ~= "string" then
            return nil, bad_format_error
          end
          buf:put(v.content)

          if just_pick_latest then
            break
          end

          buf:put(" ") -- put a separator to avoid adhesion of words
        end
      end

      collected_prompts = buf:get()

    elseif type(request.prompt) == "string" then
      collected_prompts = request.prompt

    else
      return nil, bad_format_error
    end

    if not collected_prompts then
      return nil, "no 'prompt' or 'messages' received"
    end


    -- check the prompt for explicit ban patterns
    for _, v in ipairs(conf.deny_patterns or EMPTY) do
      -- check each deny pattern; if the prompt matches it, deny immediately
      local m, _, err = ngx_re_find(collected_prompts, v, "jo")
      if err then
        -- regex failed, that's an error by the administrator
        kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err)
        return kong.response.exit(500)

      elseif m then
        return nil, "prompt pattern is blocked"
      end
    end


    local allow_patterns = conf.allow_patterns or EMPTY
    if #allow_patterns == 0 then
      -- no allow_patterns, so we're good
      return true
    end

    -- if any allow_patterns specified, make sure the prompt matches one of them
    for _, v in ipairs(allow_patterns) do
      -- check each allow pattern; the first match accepts the prompt
      local m, _, err = ngx_re_find(collected_prompts, v, "jo")

      if err then
        -- regex failed, that's an error by the administrator
        kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err)
        return kong.response.exit(500)

      elseif m then
        return true  -- got a match so is allowed, exit early
      end
    end

    return false, "prompt doesn't match any allowed pattern"
  end
end

-- Test hook: `execute` is a file-local function, so it is attached to the
-- module table only when the global `_TEST` flag is set before this file loads.
if _G._TEST then
-- only if we're testing export this function (using a different name!)
_M._execute = execute
end

-- Filter entry point: validates the request body currently in use against the
-- configured prompt patterns.
-- @tparam table conf The plugin configuration
-- @return `true` when the prompt passed the checks; otherwise the result of
-- ending the request with a 400 response
function _M:run(conf)
  -- if plugin ordering was altered, receive the "decorated" request
  local body = ai_plugin_ctx.get_request_body_table_inuse()
  if not body then
    return bad_request("this LLM route only supports application/json requests")
  end

  -- apply the deny/allow pattern checks to the prompt
  local passed, reason = execute(body, conf)
  if passed then
    set_ctx("guarded", true)
    return true
  end

  -- the detailed reason only goes to the log
  kong.log.debug(reason)
  return bad_request("bad request") -- don't let users know 'ai-prompt-guard' is in use
end

return _M
147 changes: 7 additions & 140 deletions kong/plugins/ai-prompt-guard/handler.lua
Original file line number Diff line number Diff line change
@@ -1,144 +1,11 @@
local buffer = require("string.buffer")
local llm_state = require("kong.llm.state")
local ngx_re_find = ngx.re.find
local EMPTY = {}
local ai_plugin_base = require("kong.llm.plugin.base")

-- Plugin name and priority, passed to `ai_plugin_base.define` below.
local NAME = "ai-prompt-guard"
local PRIORITY = 771

local AIPlugin = ai_plugin_base.define(NAME, PRIORITY)

-- Plain handler table carrying the same priority value and the Kong version.
local plugin = {
PRIORITY = 771,
VERSION = require("kong.meta").version
}
-- Register and enable the filters: the shared parse-request filter first,
-- then this plugin's own guard-prompt filter.
AIPlugin:enable(AIPlugin.register_filter(require("kong.llm.plugin.shared-filters.parse-request")))
AIPlugin:enable(AIPlugin.register_filter(require("kong.plugins." .. NAME .. ".filters.guard-prompt")))



-- Logs `msg` at debug level, then ends the request with a 400 response whose
-- JSON body is `{ error = { message = msg } }`.
-- @tparam string msg The message to log and to send back to the client
-- @return whatever `kong.response.exit` returns
local function bad_request(msg)
  kong.log.debug(msg)

  local body = {
    error = {
      message = msg,
    },
  }

  return kong.response.exit(400, body)
end



local execute do
  local bad_format_error = "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts"

  -- Checks the prompt against the configured deny and allow patterns.
  -- _Note_: if a regex fails to execute, the error is logged and the request
  -- is ended through `kong.response.exit(500)`.
  -- @tparam table request The deserialized JSON body of the request
  -- @tparam table conf The plugin configuration
  -- @treturn[1] boolean `true` when the prompt is accepted
  -- @treturn[2] nil|boolean `nil` (bad format / denied) or `false` (not allowed)
  -- @treturn[2] string The error message
  function execute(request, conf)
    local collected_prompts
    local messages = request.messages

    -- concat all prompts into one string, if conversation history must be checked
    if type(messages) == "table" then
      local buf = buffer.new()
      -- Note: allow_all_conversation_history means the history is ignored,
      -- so only the latest matching message is collected
      local just_pick_latest = conf.allow_all_conversation_history

      -- iterate in reverse so we get the latest user prompt first
      -- instead of the oldest one in history
      for i = #messages, 1, -1 do
        local v = messages[i]

        -- entries come straight from the client's JSON body: anything that is
        -- not a table (a number, a JSON null, ...) must be rejected as a bad
        -- format before it is indexed, otherwise `v.role` can raise a Lua error
        if type(v) ~= "table" or type(v.role) ~= "string" then
          return nil, bad_format_error
        end

        if v.role == "user" or conf.match_all_roles then
          if type(v.content) ~= "string" then
            return nil, bad_format_error
          end
          buf:put(v.content)

          if just_pick_latest then
            break
          end

          buf:put(" ") -- put a separator to avoid adhesion of words
        end
      end

      collected_prompts = buf:get()

    elseif type(request.prompt) == "string" then
      collected_prompts = request.prompt

    else
      return nil, bad_format_error
    end

    if not collected_prompts then
      return nil, "no 'prompt' or 'messages' received"
    end


    -- check the prompt for explicit ban patterns
    for _, v in ipairs(conf.deny_patterns or EMPTY) do
      -- check each deny pattern; if the prompt matches it, deny immediately
      local m, _, err = ngx_re_find(collected_prompts, v, "jo")
      if err then
        -- regex failed, that's an error by the administrator
        kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err)
        return kong.response.exit(500)

      elseif m then
        return nil, "prompt pattern is blocked"
      end
    end


    local allow_patterns = conf.allow_patterns or EMPTY
    if #allow_patterns == 0 then
      -- no allow_patterns, so we're good
      return true
    end

    -- if any allow_patterns specified, make sure the prompt matches one of them
    for _, v in ipairs(allow_patterns) do
      -- check each allow pattern; the first match accepts the prompt
      local m, _, err = ngx_re_find(collected_prompts, v, "jo")

      if err then
        -- regex failed, that's an error by the administrator
        kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err)
        return kong.response.exit(500)

      elseif m then
        return true  -- got a match so is allowed, exit early
      end
    end

    return false, "prompt doesn't match any allowed pattern"
  end
end



-- Access-phase handler: enables request buffering, reads the JSON request body
-- and checks its prompt against the configured patterns.
-- @tparam table conf The plugin configuration
-- @return nothing when the prompt passed the checks; otherwise the result of
-- ending the request with a 400 response
function plugin:access(conf)
  kong.service.request.enable_buffering()
  llm_state.set_prompt_guarded()  -- future use

  -- if plugin ordering was altered, receive the "decorated" request
  local body = kong.request.get_body("application/json", nil, conf.max_request_body_size)
  if type(body) ~= "table" then
    return bad_request("this LLM route only supports application/json requests")
  end

  -- apply the deny/allow pattern checks to the prompt
  local passed, reason = execute(body, conf)
  if passed then
    return
  end

  -- the detailed reason only goes to the log
  kong.log.debug(reason)
  return bad_request("bad request") -- don't let users know 'ai-prompt-guard' is in use
end



-- Test hook: `execute` is a file-local function, so it is attached to the
-- plugin table only when the global `_TEST` flag is set before this file loads.
if _G._TEST then
-- only if we're testing export this function (using a different name!)
plugin._execute = execute
end


return plugin
return AIPlugin:as_kong_plugin()
4 changes: 2 additions & 2 deletions spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ describe(PLUGIN_NAME .. ": (unit)", function()

setup(function()
_G._TEST = true
package.loaded["kong.plugins.ai-prompt-guard.handler"] = nil
access_handler = require("kong.plugins.ai-prompt-guard.handler")
package.loaded["kong.plugins.ai-prompt-guard.filters.guard-prompt"] = nil
access_handler = require("kong.plugins.ai-prompt-guard.filters.guard-prompt")
end)

teardown(function()
Expand Down
4 changes: 2 additions & 2 deletions spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ local helpers = require "spec.helpers"
local PLUGIN_NAME = "ai-prompt-guard"


for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
for _, strategy in helpers.all_strategies() do
describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
local client

Expand Down Expand Up @@ -511,4 +511,4 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then

end)

end end
end

0 comments on commit 97234e0

Please sign in to comment.