Skip to content

Commit

Permalink
fix(pdk): function store_response_header should accept nil as `va…
Browse files Browse the repository at this point in the history
…lue`
  • Loading branch information
ADD-SP committed Jun 18, 2024
1 parent 19ee6a9 commit f77ad1e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
49 changes: 35 additions & 14 deletions kong/pdk/private/rate_limiting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@ local max_fields_n = 4
local _M = {}


local function _validate_key(key, arg_n, func_name)
local typ = type(key)
if typ ~= "string" then
local msg = string.format(
"arg #%d `key` for function `%s` must be a string, got %s",
arg_n,
func_name,
typ
)
error(msg, 3)
end
end


local function _validate_value(value, arg_n, func_name)
local typ = type(value)
if typ ~= "number" and typ ~= "nil" and typ ~= "string" then
local msg = string.format(
"arg #%d `value` for function `%s` must be a string, number or nil, got %s",
arg_n,
func_name,
typ
)
error(msg, 3)
end
end


local function _has_rl_ctx(ngx_ctx)
return ngx_ctx.__rate_limiting_context__ ~= nil
end
Expand Down Expand Up @@ -42,27 +70,20 @@ end


function _M.store_response_header(ngx_ctx, key, value)
assert(
type(key) == "string",
"arg #2 `key` for function `store_response_header` must be a string"
)

local value_type = type(value)
assert(
value_type == "string" or value_type == "number",
"arg #3 `value` for function `store_response_header` must be a string or a number"
)
_validate_key(key, 2, "store_response_header")
_validate_value(value, 3, "store_response_header")

local rl_ctx = _get_or_create_rl_ctx(ngx_ctx)
rl_ctx[key] = value
end


function _M.get_stored_response_header(ngx_ctx, key)
assert(
type(key) == "string",
"arg #2 `key` for function `get_stored_response_header` must be a string"
)
_validate_key(key, 2, "get_stored_response_header")

if not _has_rl_ctx(ngx_ctx) then
return nil
end

if not _has_rl_ctx(ngx_ctx) then
return nil
Expand Down
48 changes: 36 additions & 12 deletions t/01-pdk/16-rl-ctx.t
Original file line number Diff line number Diff line change
Expand Up @@ -140,49 +140,73 @@ X-2: 2
location = /t {
rewrite_by_lua_block {
local pdk_rl = require("kong.pdk.private.rate_limiting")
local ok, err
local ok, err, errmsg
ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, nil, 1)
assert(not ok, "pcall should fail")
errmsg = string.format(
"arg #%d `key` for function `%s` must be a string, got %s",
2,
"store_response_header",
type(nil)
)
assert(
err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true),
err:find(errmsg, nil, true),
"unexpected error message: " .. err
)
for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do
ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, v, 1)
assert(not ok, "pcall should fail")
errmsg = string.format(
"arg #%d `key` for function `%s` must be a string, got %s",
2,
"store_response_header",
type(v)
)
assert(
err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true),
err:find(errmsg, nil, true),
"unexpected error message: " .. err
)
end
ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", nil)
assert(not ok, "pcall should fail")
assert(
err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true),
"unexpected error message: " .. err
)
for _k, v in ipairs({ true, {}, function() end, ngx.null }) do
ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", v)
assert(not ok, "pcall should fail")
errmsg = string.format(
"arg #%d `value` for function `%s` must be a string, number or nil, got %s",
3,
"store_response_header",
type(v)
)
assert(
err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true),
err:find(errmsg, nil, true),
"unexpected error message: " .. err
)
end
ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, nil)
assert(not ok, "pcall should fail")
errmsg = string.format(
"arg #%d `key` for function `%s` must be a string, got %s",
2,
"get_stored_response_header",
type(nil)
)
assert(
err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true),
err:find(errmsg, nil, true),
"unexpected error message: " .. err
)
for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do
ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, v)
assert(not ok, "pcall should fail")
errmsg = string.format(
"arg #%d `key` for function `%s` must be a string, got %s",
2,
"get_stored_response_header",
type(v)
)
assert(
err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true),
err:find(errmsg, nil, true),
"unexpected error message: " .. err
)
end
Expand Down

0 comments on commit f77ad1e

Please sign in to comment.