Skip to content

Commit

Permalink
fix: prompt decorator crashes AI Gateway when trying to decorate "pre…
Browse files Browse the repository at this point in the history
…pend"
  • Loading branch information
tysoekong authored and ProBrian committed Dec 13, 2024
1 parent b20e847 commit bbafa05
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 63 deletions.
5 changes: 5 additions & 0 deletions kong/llm/plugin/ctx.lua
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ local EMPTY_REQUEST_T = _M.immutable_table({})

function _M.get_request_body_table_inuse()
local request_body_table

if _M.has_namespace("decorate-prompt") then -- has ai-prompt-decorator and others in future
request_body_table = _M.get_namespaced_ctx("decorate-prompt", "request_body_table")
end

if _M.has_namespace("normalize-request") then -- has ai-proxy/ai-proxy-advanced
request_body_table = _M.get_namespaced_ctx("normalize-request", "request_body_table")
end
Expand Down
11 changes: 9 additions & 2 deletions kong/llm/plugin/shared-filters/normalize-request.lua
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,14 @@ local function validate_and_transform(conf)
local model_t = conf_m.model
local model_provider = conf.model.provider -- use the one from conf, not the merged one to avoid potential security risk

local request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table")
local request_table
if ai_plugin_ctx.has_namespace("decorate-prompt") and
ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "decorated") then
request_table = ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "request_body_table")
else
request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table")
end

if not request_table then
return bail(400, "content-type header does not match request body, or bad JSON formatting")
end
Expand Down Expand Up @@ -219,4 +226,4 @@ function _M:run(conf)
return true
end

return _M
return _M
17 changes: 10 additions & 7 deletions kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
-- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ]

local new_tab = require("table.new")
local deep_copy = require("kong.tools.table").deep_copy
local ai_plugin_ctx = require("kong.llm.plugin.ctx")
local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy

local _M = {
NAME = "decorate-prompt",
STAGE = "REQ_TRANSFORMATION",
}
}

local FILTER_OUTPUT_SCHEMA = {
decorated = "boolean",
request_body_table = "table",
}

local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA)
Expand All @@ -24,7 +25,7 @@ local EMPTY = {}


local function bad_request(msg)
kong.log.debug(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
end

Expand All @@ -37,9 +38,6 @@ local function execute(request, conf)
local prepend = conf.prompts.prepend or EMPTY
local append = conf.prompts.append or EMPTY

-- ensure we don't modify the original request
request = deep_copy(request)

local old_messages = request.messages
local new_messages = new_tab(#append + #prepend + #old_messages, 0)
request.messages = new_messages
Expand Down Expand Up @@ -81,9 +79,14 @@ function _M:run(conf)
return bad_request("this LLM route only supports llm/chat type requests")
end

kong.service.request.set_body(execute(request_body_table, conf), "application/json")
-- Deep copy to avoid modifying the immutable table.
-- Re-assign it to trigger GC of the old one and save memory.
request_body_table = execute(cycle_aware_deep_copy(request_body_table), conf)

kong.service.request.set_body(request_body_table, "application/json") -- legacy

set_ctx("decorated", true)
set_ctx("request_body_table", request_body_table)

return true
end
Expand Down
241 changes: 187 additions & 54 deletions spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua
Original file line number Diff line number Diff line change
@@ -1,23 +1,45 @@
local helpers = require "spec.helpers"
local helpers = require("spec.helpers")
local cjson = require("cjson")


local PLUGIN_NAME = "ai-prompt-decorator"


for _, strategy in helpers.all_strategies() do
local openai_flat_chat = {
messages = {
{
role = "user",
content = "I think that cheddar is the best cheese.",
},
{
role = "assistant",
content = "No, brie is the best cheese.",
},
{
role = "user",
content = "Why brie?",
},
},
}


for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
local client

lazy_setup(function()

local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME })
local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" })


local route1 = bp.routes:insert({
hosts = { "test1.com" },
-- echo route, we don't need a mock AI here
local prepend = bp.routes:insert({
hosts = { "prepend.decorate.local" },
})

bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = route1.id },
route = { id = prepend.id },
config = {
prompts = {
prepend = {
Expand All @@ -30,6 +52,28 @@ for _, strategy in helpers.all_strategies() do
content = "Prepend text 2 here.",
},
},
},
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = prepend.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}


local append = bp.routes:insert({
hosts = { "append.decorate.local" },
})

bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = append.id },
config = {
prompts = {
append = {
[1] = {
role = "assistant",
Expand All @@ -44,72 +88,161 @@ for _, strategy in helpers.all_strategies() do
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = append.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}

local both = bp.routes:insert({
hosts = { "both.decorate.local" },
})


bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = both.id },
config = {
prompts = {
prepend = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
[2] = {
role = "assistant",
content = "Prepend text 2 here.",
},
},
append = {
[1] = {
role = "assistant",
content = "Append text 3 here.",
},
[2] = {
role = "user",
content = "Append text 4 here.",
},
},
},
},
}

bp.plugins:insert {
name = "ctx-checker-last",
route = { id = both.id },
config = {
ctx_check_field = "ai_namespaced_ctx",
}
}


assert(helpers.start_kong({
database = strategy,
nginx_conf = "spec/fixtures/custom_nginx.template",
plugins = "bundled," .. PLUGIN_NAME,
plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME,
declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
}))
end)


lazy_teardown(function()
helpers.stop_kong()
helpers.stop_kong(nil, true)
end)


before_each(function()
client = helpers.proxy_client()
end)


after_each(function()
if client then client:close() end
end)



it("blocks a non-chat message", function()
local r = client:get("/request", {
headers = {
host = "test1.com",
["Content-Type"] = "application/json",
},
body = [[
{
"anything": [
{
"random": "data"
}
]
}]],
method = "POST",
})

assert.response(r).has.status(400)
local json = assert.response(r).has.jsonbody()
assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json)
end)


it("blocks an empty messages array", function()
local r = client:get("/request", {
headers = {
host = "test1.com",
["Content-Type"] = "application/json",
},
body = [[
{
"messages": []
}]],
method = "POST",
})

assert.response(r).has.status(400)
local json = assert.response(r).has.jsonbody()
assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json)
describe("request", function()
it("modifies the LLM chat request - prepend", function()
local r = client:get("/", {
headers = {
host = "prepend.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1])
assert.same({ content = "Prepend text 2 here.", role = "system" }, request.messages[2])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Prepend text 1 here.*]])
assert.match_re(ctx, [[.*Prepend text 2 here.*]])
end)

it("modifies the LLM chat request - append", function()
local r = client:get("/", {
headers = {
host = "append.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Append text 1 here.", role = "assistant" }, request.messages[#request.messages-1])
assert.same({ content = "Append text 2 here.", role = "user" }, request.messages[#request.messages])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Append text 1 here.*]])
assert.match_re(ctx, [[.*Append text 2 here.*]])
end)


it("modifies the LLM chat request - both", function()
local r = client:get("/", {
headers = {
host = "both.decorate.local",
["Content-Type"] = "application/json"
},
body = cjson.encode(openai_flat_chat),
})

-- get the REQUEST body, that left Kong for the upstream, using the echo system
assert.response(r).has.status(200)
local request = assert.response(r).has.jsonbody()
request = cjson.decode(request.post_data.text)

assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1])
assert.same({ content = "Prepend text 2 here.", role = "assistant" }, request.messages[2])
assert.same({ content = "Append text 3 here.", role = "assistant" }, request.messages[#request.messages-1])
assert.same({ content = "Append text 4 here.", role = "user" }, request.messages[#request.messages])

-- check ngx.ctx was set properly for later AI chain filters
local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx")
ctx = ngx.unescape_uri(ctx)
assert.match_re(ctx, [[.*decorate-prompt.*]])
assert.match_re(ctx, [[.*decorated = true.*]])
assert.match_re(ctx, [[.*Prepend text 1 here.*]])
assert.match_re(ctx, [[.*Prepend text 2 here.*]])
assert.match_re(ctx, [[.*Append text 3 here.*]])
assert.match_re(ctx, [[.*Append text 4 here.*]])
end)
end)

end)

end
end end

0 comments on commit bbafa05

Please sign in to comment.