Skip to content

Commit

Permalink
fix(plugin): cleanup and fix error handling (#12912)
Browse files Browse the repository at this point in the history
* chore(ai-prompt-template): cleanup code improve error handling

* review comments
  • Loading branch information
Tieske authored May 8, 2024
1 parent 328097a commit 0b1705e
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 219 deletions.
107 changes: 43 additions & 64 deletions kong/plugins/ai-prompt-template/handler.lua
Original file line number Diff line number Diff line change
@@ -1,101 +1,80 @@
local _M = {}

local templater = require("kong.plugins.ai-prompt-template.templater")
local ipairs = ipairs
local type = type

-- imports
local kong_meta = require("kong.meta")
local templater = require("kong.plugins.ai-prompt-template.templater"):new()
local fmt = string.format
local parse_url = require("socket.url").parse
local byte = string.byte
local sub = string.sub
local type = type
local ipairs = ipairs
--

_M.PRIORITY = 773
_M.VERSION = kong_meta.version

local AIPromptTemplateHandler = {
PRIORITY = 773,
VERSION = require("kong.meta").version,
}


local log_entry_keys = {

local LOG_ENTRY_KEYS = {
REQUEST_BODY = "ai.payload.original_request",
}


-- reuse this table for error message response
local ERROR_MSG = { error = { message = "" } }


local function bad_request(msg)
kong.log.debug(msg)
ERROR_MSG.error.message = msg

return kong.response.exit(ngx.HTTP_BAD_REQUEST, ERROR_MSG)
return kong.response.exit(400, { error = { message = msg } })
end


local BRACE_START = byte("{")
local BRACE_END = byte("}")
local COLON = byte(":")
local SLASH = byte("/")


---- BORROWED FROM `kong.pdk.vault`
---
-- Checks if the passed in reference looks like a reference.
-- Checks if the passed in reference looks like a reference, and returns the template name.
-- Valid references start with '{template://' and end with '}'.
--
-- @local
-- @function is_reference
-- @tparam string reference reference to check
-- @treturn boolean `true` is the passed in reference looks like a reference, otherwise `false`
local function is_reference(reference)
return type(reference) == "string"
and byte(reference, 1) == BRACE_START
and byte(reference, -1) == BRACE_END
and byte(reference, 10) == COLON
and byte(reference, 11) == SLASH
and byte(reference, 12) == SLASH
and sub(reference, 2, 9) == "template"
-- @treturn string the reference template name or nil if it's not a reference
local function extract_template_name(reference)
if type(reference) ~= "string" then
return nil
end

if not (reference:sub(1, 12) == "{template://" and reference:sub(-1) == "}") then
return nil
end

return reference:sub(13, -2)
end


local function find_template(reference_string, templates)
local parts, err = parse_url(sub(reference_string, 2, -2))
if not parts then
return nil, fmt("template reference is not in format '{template://template_name}' (%s) [%s]", err, reference_string)
end

-- iterate templates to find it
--- Find a template by name in the list of templates.
-- @tparam string reference_name the name of the template to find
-- @tparam table templates the list of templates to search
-- @treturn string the template if found, or nil + error message if not found
local function find_template(reference_name, templates)
for _, v in ipairs(templates) do
if v.name == parts.host then
if v.name == reference_name then
return v, nil
end
end

return nil, fmt("could not find template name [%s]", parts.host)
return nil, "could not find template name [" .. reference_name .. "]"
end


function _M:access(conf)

function AIPromptTemplateHandler:access(conf)
kong.service.request.enable_buffering()
kong.ctx.shared.ai_prompt_templated = true

if conf.log_original_request then
kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body())
kong.log.set_serialize_value(LOG_ENTRY_KEYS.REQUEST_BODY, kong.request.get_raw_body())
end

local request, err = kong.request.get_body("application/json")
if err then
local request = kong.request.get_body("application/json")
if type(request) ~= "table" then
return bad_request("this LLM route only supports application/json requests")
end

local messages = request.messages
local prompt = request.prompt

if (not messages) and (not prompt) then
return bad_request("this LLM route only supports llm/chat or llm/completions type requests")
end

if messages and prompt then
return bad_request("cannot run 'messages' and 'prompt' templates at the same time")
end
Expand All @@ -105,22 +84,22 @@ function _M:access(conf)
return bad_request("only 'llm/v1/chat' and 'llm/v1/completions' formats are supported for templating")
end

if not is_reference(reference) then
if not (conf.allow_untemplated_requests) then
return bad_request("this LLM route only supports templated requests")
local template_name = extract_template_name(reference)
if not template_name then
if conf.allow_untemplated_requests then
return -- not a reference, do nothing
end

-- not reference, do nothing
return
return bad_request("this LLM route only supports templated requests")
end

local requested_template, err = find_template(reference, conf.templates)
local requested_template, err = find_template(template_name, conf.templates)
if not requested_template then
return bad_request(err)
end

-- try to render the replacement request
local rendered_template, err = templater:render(requested_template, request.properties or {})
local rendered_template, err = templater.render(requested_template, request.properties or {})
if err then
return bad_request(err)
end
Expand All @@ -129,4 +108,4 @@ function _M:access(conf)
end


return _M
return AIPromptTemplateHandler
125 changes: 49 additions & 76 deletions kong/plugins/ai-prompt-template/templater.lua
Original file line number Diff line number Diff line change
@@ -1,93 +1,66 @@
local _S = {}

-- imports
local fmt = string.format
--

-- globals
local GSUB_REPLACE_PATTERN = "{{([%w_]+)}}"
--

local function backslash_replacement_function(c)
if c == "\n" then
return "\\n"
elseif c == "\r" then
return "\\r"
elseif c == "\t" then
return "\\t"
elseif c == "\b" then
return "\\b"
elseif c == "\f" then
return "\\f"
elseif c == '"' then
return '\\"'
elseif c == '\\' then
return '\\\\'
else
return string.format("\\u%04x", c:byte())
end
end
local cjson = require "cjson.safe"

local chars_to_be_escaped_in_JSON_string
= '['
.. '"' -- class sub-pattern to match a double quote
.. '%\\' -- class sub-pattern to match a backslash
.. '%z' -- class sub-pattern to match a null
.. '\001' .. '-' .. '\031' -- class sub-pattern to match control characters
.. ']'

-- borrowed from turbo-json
local function sanitize_parameter(s)
if type(s) ~= "string" or s == "" then
return nil, nil, "only string arguments are supported"
end
local _M = {}

-- check if someone is trying to inject JSON control characters to close the command
if s:sub(-1) == "," then
s = s:sub(1, -1)
end

return s:gsub(chars_to_be_escaped_in_JSON_string, backslash_replacement_function), nil
end

function _S:new(o)
local o = o or {}
setmetatable(o, self)
self.__index = self
--- Sanitize properties object.
-- Incoming user-provided JSON object may contain any kind of data.
-- @tparam table params the kv table to sanitize
-- @treturn[1] table the escaped values (without quotes)
-- @treturn[2] nil
-- @treturn[2] string error message
local function sanitize_properties(params)
local result = {}

return o
if type(params) ~= "table" then
return nil, "properties must be an object"
end

for k,v in pairs(params) do
if type(k) ~= "string" then
return nil, "properties must be an object"
end
if type(v) == "string" then
result[k] = cjson.encode(v):sub(2, -2) -- remove quotes
else
return nil, "property values must be a string, got " .. type(v)
end
end

return result
end


function _S:render(template, properties)
local sanitized_properties = {}
local err, _

for k, v in pairs(properties) do
sanitized_properties[k], _, err = sanitize_parameter(v)
if err then return nil, err end
end
do
local GSUB_REPLACE_PATTERN = "{{([%w_]+)}}"

local result = template.template:gsub(GSUB_REPLACE_PATTERN, sanitized_properties)
function _M.render(template, properties)
local sanitized_properties, err = sanitize_properties(properties)
if not sanitized_properties then
return nil, err
end

-- find any missing variables
local errors = {}
local error_string
for w in (result):gmatch(GSUB_REPLACE_PATTERN) do
errors[w] = true
end
local result = template.template:gsub(GSUB_REPLACE_PATTERN, sanitized_properties)

if next(errors) ~= nil then
for k, _ in pairs(errors) do
if not error_string then
error_string = fmt("missing template parameters: [%s]", k)
else
error_string = fmt("%s, [%s]", error_string, k)
-- find any missing variables
local errors = {}
local seen_before = {}
for w in result:gmatch(GSUB_REPLACE_PATTERN) do
if not seen_before[w] then
seen_before[w] = true
errors[#errors+1] = "[" .. w .. "]"
end
end
end

return result, error_string
if errors[1] then
return nil, "missing template parameters: " .. table.concat(errors, ", ")
end

return result
end
end

return _S

return _M
20 changes: 14 additions & 6 deletions spec/03-plugins/43-ai-prompt-template/01-unit_spec.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
local PLUGIN_NAME = "ai-prompt-template"

-- imports
local templater = require("kong.plugins.ai-prompt-template.templater"):new()
--

local good_chat_template = {
template = [[
Expand Down Expand Up @@ -80,22 +77,33 @@ local good_prompt_template = {
}
local good_expected_prompt = "Make me a program to do fibonacci sequence in python."



describe(PLUGIN_NAME .. ": (unit)", function()

local templater

setup(function()
templater = require("kong.plugins.ai-prompt-template.templater")
end)


it("templates chat messages", function()
local rendered_template, err = templater:render(good_chat_template, templated_chat_request.parameters)
local rendered_template, err = templater.render(good_chat_template, templated_chat_request.parameters)
assert.is_nil(err)
assert.same(rendered_template, good_expected_chat)
end)


it("templates a prompt", function()
local rendered_template, err = templater:render(good_prompt_template, templated_prompt_request.parameters)
local rendered_template, err = templater.render(good_prompt_template, templated_prompt_request.parameters)
assert.is_nil(err)
assert.same(rendered_template, good_expected_prompt)
end)


it("prohibits json injection", function()
local rendered_template, err = templater:render(good_chat_template, templated_chat_request_inject_json.parameters)
local rendered_template, err = templater.render(good_chat_template, templated_chat_request_inject_json.parameters)
assert.is_nil(err)
assert.same(rendered_template, inject_json_expected_chat)
end)
Expand Down
Loading

1 comment on commit 0b1705e

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bazel Build

Docker image available kong/kong:0b1705ec0a4480596f6a1d2923a809d1f458a07b
Artifacts available https://github.com/Kong/kong/actions/runs/9000432206

Please sign in to comment.