Skip to content

Commit

Permalink
fix(ai-templater): improved error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tysoekong committed Jan 24, 2024
1 parent a545f53 commit 746c22b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
33 changes: 17 additions & 16 deletions kong/plugins/ai-prompt-template/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,19 @@ local function is_reference(reference)
end

local function find_template(reference_string, templates)
if is_reference(reference_string) then
local parts, err = parse_url(sub(reference_string, 2, -2))
if not parts then
return nil, fmt("template reference is not in format '{template://template_name}' (%s) [%s]", err, reference_string)
end
local parts, err = parse_url(sub(reference_string, 2, -2))
if not parts then
return nil, fmt("template reference is not in format '{template://template_name}' (%s) [%s]", err, reference_string)
end

-- iterate templates to find it
for i, v in ipairs(templates) do
if v.name == parts.host then
return v, nil
end
-- iterate templates to find it
for i, v in ipairs(templates) do
if v.name == parts.host then
return v, nil
end

return nil, "could not find template name [" .. parts.host .. "]"
end

return nil, "'messages' template reference should be a single string, in format '{template://template_name}'"
return nil, fmt("could not find template name [%s]", parts.host)
end

function _M:access(conf)
Expand Down Expand Up @@ -100,17 +96,22 @@ function _M:access(conf)
return bad_request("only 'llm/v1/chat' and 'llm/v1/completions' formats are supported for templating")
end

local requested_template, err = find_template(reference, conf.templates)
if err and (not conf.allow_untemplated_requests) then bad_request(err) end
if is_reference(reference) then
local requested_template, err = find_template(reference, conf.templates)
if not requested_template then
return bad_request(err)
end

if not err then
-- try to render the replacement request
local rendered_template, err = templater:render(requested_template, request.properties or {})
if err then
return bad_request(err)
end

kong.service.request.set_raw_body(rendered_template)

elseif not (conf.allow_untemplated_requests) then
return bad_request("this LLM route only supports templated requests")
end
end

Expand Down
36 changes: 32 additions & 4 deletions spec/03-plugins/43-ai-prompt-template/02-integration_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
]],
},
[2] = {
name = "developer-completions",
template = [[
{
"prompt": "You are a {{language}} programming expert. Make me a {{program}} program."
}
]],
},
},
},
}
Expand Down Expand Up @@ -211,7 +219,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local body = assert.res_status(400, r)
local json = cjson.decode(body)

assert.same(json, { error = { message = "'messages' template reference should be a single string, in format '{template://template_name}'" }})
assert.same(json, { error = { message = "this LLM route only supports templated requests" }})
end)

it("doesn't block when 'allow_untemplated_requests' is ON", function()
Expand All @@ -232,7 +240,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
]],
method = "POST",
})

local body = assert.res_status(200, r)
local json = cjson.decode(body)

Expand All @@ -256,13 +264,33 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
]],
method = "POST",
})

local body = assert.res_status(400, r)
local json = cjson.decode(body)

assert.same(json, { error = { message = "could not find template name [developer-doesnt-exist]" }} )
end)

it("still errors with a not found template when 'allow_untemplated_requests' is ON", function()
local r = client:get("/request", {
headers = {
host = "test1.com",
["Content-Type"] = "application/json",
},
body = [[
{
"messages": "{template://not_found}"
}
]],
method = "POST",
})

local body = assert.res_status(400, r)
local json = cjson.decode(body)

assert.same(json, { error = { message = "could not find template name [not_found]" }} )
end)

it("errors with missing template parameter", function()
local r = client:get("/request", {
headers = {
Expand All @@ -279,7 +307,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
]],
method = "POST",
})

local body = assert.res_status(400, r)
local json = cjson.decode(body)

Expand Down

0 comments on commit 746c22b

Please sign in to comment.