diff --git a/kong/pdk/private/rate_limiting.lua b/kong/pdk/private/rate_limiting.lua index 1710372a71df..a267c0a20652 100644 --- a/kong/pdk/private/rate_limiting.lua +++ b/kong/pdk/private/rate_limiting.lua @@ -12,6 +12,34 @@ local max_fields_n = 4 local _M = {} +local function _validate_key(key, arg_n, func_name) + local typ = type(key) + if typ ~= "string" then + local msg = string.format( + "arg #%d `key` for function `%s` must be a string, got %s", + arg_n, + func_name, + typ + ) + error(msg, 3) + end +end + + +local function _validate_value(value, arg_n, func_name) + local typ = type(value) + if typ ~= "number" and typ ~= "string" then + local msg = string.format( + "arg #%d `value` for function `%s` must be a string or a number, got %s", + arg_n, + func_name, + typ + ) + error(msg, 3) + end +end + + local function _has_rl_ctx(ngx_ctx) return ngx_ctx.__rate_limiting_context__ ~= nil end @@ -42,16 +70,8 @@ end function _M.store_response_header(ngx_ctx, key, value) - assert( - type(key) == "string", - "arg #2 `key` for function `store_response_header` must be a string" - ) - - local value_type = type(value) - assert( - value_type == "string" or value_type == "number", - "arg #3 `value` for function `store_response_header` must be a string or a number" - ) + _validate_key(key, 2, "store_response_header") + _validate_value(value, 3, "store_response_header") local rl_ctx = _get_or_create_rl_ctx(ngx_ctx) rl_ctx[key] = value @@ -59,10 +79,11 @@ end function _M.get_stored_response_header(ngx_ctx, key) - assert( - type(key) == "string", - "arg #2 `key` for function `get_stored_response_header` must be a string" - ) + _validate_key(key, 2, "get_stored_response_header") + + if not _has_rl_ctx(ngx_ctx) then + return nil + end if not _has_rl_ctx(ngx_ctx) then return nil diff --git a/t/01-pdk/16-rl-ctx.t b/t/01-pdk/16-rl-ctx.t index 2c233fd9c01f..66519c9a19c3 100644 --- a/t/01-pdk/16-rl-ctx.t +++ b/t/01-pdk/16-rl-ctx.t @@ -140,49 +140,85 @@ X-2: 2 location = /t { rewrite_by_lua_block { local pdk_rl = require("kong.pdk.private.rate_limiting") - local ok, err + local ok, err, errmsg ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, nil, 1) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `key` for function `%s` must be a string, got %s", + 2, + "store_response_header", + type(nil) + ) assert( - err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, v, 1) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `key` for function `%s` must be a string, got %s", + 2, + "store_response_header", + type(v) + ) assert( - err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) end ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", nil) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `value` for function `%s` must be a string or a number, got %s", + 3, + "store_response_header", + type(nil) + ) assert( - err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) for _k, v in ipairs({ true, {}, function() end, ngx.null }) do ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", v) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `value` for function `%s` must be a string or a number, got %s", + 3, + "store_response_header", + type(v) + ) assert( - err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) end ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, nil) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `key` for function `%s` must be a string, got %s", + 2, + "get_stored_response_header", + type(nil) + ) assert( - err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, v) assert(not ok, "pcall should fail") + errmsg = string.format( + "arg #%d `key` for function `%s` must be a string, got %s", + 2, + "get_stored_response_header", + type(v) + ) assert( - err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true), + err:find(errmsg, nil, true), "unexpected error message: " .. err ) end