From 194734cb97ec64975b909cd7c68e541a71d6f75c Mon Sep 17 00:00:00 2001 From: Qi Date: Mon, 17 Jun 2024 13:30:31 +0800 Subject: [PATCH] refactor(pdk): simplify private RL pdk --- kong/pdk/private/rate_limiting.lua | 268 ++---------------- kong/plugins/response-ratelimiting/access.lua | 10 +- .../response-ratelimiting/header_filter.lua | 17 +- t/01-pdk/16-rl-ctx.t | 200 +++++++++++++ 4 files changed, 238 insertions(+), 257 deletions(-) create mode 100644 t/01-pdk/16-rl-ctx.t diff --git a/kong/pdk/private/rate_limiting.lua b/kong/pdk/private/rate_limiting.lua index c730d8594266..5d3ae0d696bb 100644 --- a/kong/pdk/private/rate_limiting.lua +++ b/kong/pdk/private/rate_limiting.lua @@ -1,5 +1,5 @@ local table_new = require("table.new") -local buffer = require("string.buffer") +local table_clear = require("table.clear") local type = type local pairs = pairs @@ -7,80 +7,8 @@ local assert = assert local tostring = tostring local resp_header = ngx.header -local tablex_keys = require("pl.tablex").keys - -local RL_LIMIT = "RateLimit-Limit" -local RL_REMAINING = "RateLimit-Remaining" -local RL_RESET = "RateLimit-Reset" -local RETRY_AFTER = "Retry-After" - - -- determine the number of pre-allocated fields at runtime local max_fields_n = 4 -local buf = buffer.new(64) - -local LIMIT_BY = { - second = { - limit = "X-RateLimit-Limit-Second", - remain = "X-RateLimit-Remaining-Second", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Second", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Second", - }, - minute = { - limit = "X-RateLimit-Limit-Minute", - remain = "X-RateLimit-Remaining-Minute", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Minute", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Minute", - }, - hour = { - limit = "X-RateLimit-Limit-Hour", - remain = "X-RateLimit-Remaining-Hour", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Hour", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Hour", - }, - day = { - limit = "X-RateLimit-Limit-Day", - remain = "X-RateLimit-Remaining-Day", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Day", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Day", - }, - month = { - limit = "X-RateLimit-Limit-Month", - remain = "X-RateLimit-Remaining-Month", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Month", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Month", - }, - year = { - limit = "X-RateLimit-Limit-Year", - remain = "X-RateLimit-Remaining-Year", - limit_segment_0 = "X-", - limit_segment_1 = "RateLimit-Limit-", - limit_segment_3 = "-Year", - remain_segment_0 = "X-", - remain_segment_1 = "RateLimit-Remaining-", - remain_segment_3 = "-Year", - }, -} local _M = {} @@ -114,201 +42,39 @@ local function _get_or_create_rl_ctx(ngx_ctx) end -function _M.set_basic_limit(ngx_ctx, limit, remaining, reset) - local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx) - - assert( - type(limit) == "number", - "arg #2 `limit` for `set_basic_limit` must be a number" - ) - assert( - type(remaining) == "number", - "arg #3 `remaining` for `set_basic_limit` must be a number" - ) +function _M.store_response_header(ngx_ctx, key, value) assert( - type(reset) == "number", - "arg #4 `reset` for `set_basic_limit` must be a number" + type(key) == "string", + "arg #2 `key` for function `store_response_header` must be a string" ) - rl_ctx[RL_LIMIT] = limit - rl_ctx[RL_REMAINING] = remaining - rl_ctx[RL_RESET] = reset -end - -function _M.set_retry_after(ngx_ctx, reset) - local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx) - - assert( - type(reset) == "number", - "arg #2 `reset` for `set_retry_after` must be a number" - ) - - rl_ctx[RETRY_AFTER] = reset -end - -function _M.set_limit_by(ngx_ctx, limit_by, limit, remaining) - local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx) - - assert( - type(limit_by) == "string", - "arg #2 `limit_by` for `set_limit_by` must be a string" - ) - assert( - type(limit) == "number", - "arg #3 `limit` for `set_limit_by` must be a number" - ) + local value_type = type(value) assert( - type(remaining) == "number", - "arg #4 `remaining` for `set_limit_by` must be a number" + value_type == "string" or value_type == "number", + "arg #3 `value` for function `store_response_header` must be a string or a number" ) - limit_by = LIMIT_BY[limit_by] - assert(limit_by, "invalid limit_by") - - rl_ctx[limit_by.limit] = limit - rl_ctx[limit_by.remain] = remaining + local rl_ctx = _get_or_create_rl_ctx(ngx_ctx) + rl_ctx[key] = value end -function _M.set_limit_by_with_identifier(ngx_ctx, limit_by, limit, remaining, id_seg_1, id_seg_2) - local rl_ctx = _get_or_create_rl_ctx(ngx_ctx or ngx.ctx) +function _M.get_stored_response_header(ngx_ctx, key) assert( - type(limit_by) == "string", - "arg #2 `limit_by` for `set_limit_by_with_identifier` must be a string" - ) - assert( - type(limit) == "number", - "arg #3 `limit` for `set_limit_by_with_identifier` must be a number" - ) - assert( - type(remaining) == "number", - "arg #4 `remaining` for `set_limit_by_with_identifier` must be a number" - ) - - local id_seg_1_typ = type(id_seg_1) - local id_seg_2_typ = type(id_seg_2) - assert( - id_seg_1_typ == "nil" or id_seg_1_typ == "string", - "arg #5 `id_seg_1` for `set_limit_by_with_identifier` must be a string or nil" - ) - assert( - id_seg_2_typ == "nil" or id_seg_2_typ == "string", - "arg #6 `id_seg_2` for `set_limit_by_with_identifier` must be a string or nil" + type(key) == "string", + "arg #2 `key` for function `get_stored_response_header` must be a string" ) - limit_by = LIMIT_BY[limit_by] - if not limit_by then - local valid_limit_bys = tablex_keys(LIMIT_BY) - local msg = string.format( - "arg #2 `limit_by` for `set_limit_by_with_identifier` must be one of: %s", - table.concat(valid_limit_bys, ", ") - ) - error(msg) + if not _has_rl_ctx(ngx_ctx) then + return nil end - id_seg_1 = id_seg_1 or "" - id_seg_2 = id_seg_2 or "" - - -- construct the key like X--RateLimit-Limit-- - local limit_key = buf:reset():put( - limit_by.limit_segment_0, - id_seg_1, - limit_by.limit_segment_1, - id_seg_2, - limit_by.limit_segment_3 - ):get() - - -- construct the key like X--RateLimit-Remaining-- - local remain_key = buf:reset():put( - limit_by.remain_segment_0, - id_seg_1, - limit_by.remain_segment_1, - id_seg_2, - limit_by.remain_segment_3 - ):get() - - rl_ctx[limit_key] = limit - rl_ctx[remain_key] = remaining -end - -function _M.get_basic_limit(ngx_ctx) - local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx) - return rl_ctx[RL_LIMIT], rl_ctx[RL_REMAINING], rl_ctx[RL_RESET] -end - -function _M.get_retry_after(ngx_ctx) - local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx) - return rl_ctx[RETRY_AFTER] -end - -function _M.get_limit_by(ngx_ctx, limit_by) - local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx) - - assert( - type(limit_by) == "string", - "arg #2 `limit_by` for `get_limit_by` must be a string" - ) - - limit_by = LIMIT_BY[limit_by] - assert(limit_by, "invalid limit_by") - - return rl_ctx[limit_by.limit], rl_ctx[limit_by.remain] + local rl_ctx = _get_rl_ctx(ngx_ctx) + return rl_ctx[key] end -function _M.get_limit_by_with_identifier(ngx_ctx, limit_by, id_seg_1, id_seg_2) - local rl_ctx = _get_rl_ctx(ngx_ctx or ngx.ctx) - - assert( - type(limit_by) == "string", - "arg #2 `limit_by` for `get_limit_by_with_identifier` must be a string" - ) - - local id_seg_1_typ = type(id_seg_1) - local id_seg_2_typ = type(id_seg_2) - assert( - id_seg_1_typ == "nil" or id_seg_1_typ == "string", - "arg #3 `id_seg_1` for `get_limit_by_with_identifier` must be a string or nil" - ) - assert( - id_seg_2_typ == "nil" or id_seg_2_typ == "string", - "arg #4 `id_seg_2` for `get_limit_by_with_identifier` must be a string or nil" - ) - - limit_by = LIMIT_BY[limit_by] - if not limit_by then - local valid_limit_bys = tablex_keys(LIMIT_BY) - local msg = string.format( - "arg #2 `limit_by` for `get_limit_by_with_identifier` must be one of: %s", - table.concat(valid_limit_bys, ", ") - ) - error(msg) - end - - id_seg_1 = id_seg_1 or "" - id_seg_2 = id_seg_2 or "" - - -- construct the key like X--RateLimit-Limit-- - local limit_key = buf:reset():put( - limit_by.limit_segment_0, - id_seg_1, - limit_by.limit_segment_1, - id_seg_2, - limit_by.limit_segment_3 - ):get() - - -- construct the key like X--RateLimit-Remaining-- - local remain_key = buf:reset():put( - limit_by.remain_segment_0, - id_seg_1, - limit_by.remain_segment_1, - id_seg_2, - limit_by.remain_segment_3 - ):get() - - return rl_ctx[limit_key], rl_ctx[remain_key] -end -function _M.set_response_headers(ngx_ctx) +function _M.apply_response_headers(ngx_ctx) if not _has_rl_ctx(ngx_ctx) then return end diff --git a/kong/plugins/response-ratelimiting/access.lua b/kong/plugins/response-ratelimiting/access.lua index 00078502c936..bba16660015e 100644 --- a/kong/plugins/response-ratelimiting/access.lua +++ b/kong/plugins/response-ratelimiting/access.lua @@ -1,5 +1,6 @@ local policies = require "kong.plugins.response-ratelimiting.policies" local timestamp = require "kong.tools.timestamp" +local pdk_private_rl = require "kong.pdk.private.rate_limiting" local kong = kong @@ -9,6 +10,10 @@ local error = error local tostring = tostring +local pdk_rl_store_response_header = pdk_private_rl.store_response_header +local pdk_rl_apply_response_headers = pdk_private_rl.apply_response_headers + + local EMPTY = {} local HTTP_TOO_MANY_REQUESTS = 429 local RATELIMIT_REMAINING = "X-RateLimit-Remaining" @@ -84,6 +89,7 @@ function _M.execute(conf) end -- Append usage headers to the upstream request. Also checks "block_on_first_violation". + local ngx_ctx = ngx.ctx for k in pairs(conf.limits) do local remaining for _, lv in pairs(usage[k]) do @@ -97,9 +103,11 @@ function _M.execute(conf) end end - kong.service.request.set_header(RATELIMIT_REMAINING .. "-" .. k, remaining) + pdk_rl_store_response_header(ngx_ctx, RATELIMIT_REMAINING .. "-" .. k, remaining) end + pdk_rl_apply_response_headers(ngx_ctx) + kong.ctx.plugin.usage = usage -- For later use end diff --git a/kong/plugins/response-ratelimiting/header_filter.lua b/kong/plugins/response-ratelimiting/header_filter.lua index e45c0ee5b480..65885b627c51 100644 --- a/kong/plugins/response-ratelimiting/header_filter.lua +++ b/kong/plugins/response-ratelimiting/header_filter.lua @@ -13,8 +13,12 @@ local math_max = math.max local strip = kong_string.strip local split = kong_string.split -local pdk_rl_set_response_headers = pdk_private_rl.set_response_headers -local pdk_rl_set_limit_by_with_identifier = pdk_private_rl.set_limit_by_with_identifier +local pdk_rl_store_response_header = pdk_private_rl.store_response_header +local pdk_rl_apply_response_headers = pdk_private_rl.apply_response_headers + + +local RATELIMIT_LIMIT = "X-RateLimit-Limit" +local RATELIMIT_REMAINING = "X-RateLimit-Remaining" local function parse_header(header_value, limits) @@ -68,8 +72,12 @@ function _M.execute(conf) for limit_name in pairs(usage) do for period_name, lv in pairs(usage[limit_name]) do if not conf.hide_client_headers then + local limit_hdr = RATELIMIT_LIMIT .. "-" .. limit_name .. "-" .. period_name + local remain_hdr = RATELIMIT_REMAINING .. "-" .. limit_name .. "-" .. period_name local remain = math_max(0, lv.remaining - (increments[limit_name] and increments[limit_name] or 0)) - pdk_rl_set_limit_by_with_identifier(ngx_ctx, period_name, lv.limit, remain, nil, limit_name) + + pdk_rl_store_response_header(ngx_ctx, limit_hdr, lv.limit) + pdk_rl_store_response_header(ngx_ctx, remain_hdr, remain) end if increments[limit_name] and increments[limit_name] > 0 and lv.remaining <= 0 then @@ -78,8 +86,7 @@ function _M.execute(conf) end end - -- Set rate-limiting response headers - pdk_rl_set_response_headers(ngx_ctx) + pdk_rl_apply_response_headers(ngx_ctx) kong.response.clear_header(conf.header_name) diff --git a/t/01-pdk/16-rl-ctx.t b/t/01-pdk/16-rl-ctx.t new file mode 100644 index 000000000000..dc94d1221627 --- /dev/null +++ b/t/01-pdk/16-rl-ctx.t @@ -0,0 +1,200 @@ +use strict; +use warnings FATAL => 'all'; +use Test::Nginx::Socket::Lua; +do "./t/Util.pm"; + +plan tests => repeat_each() * (blocks() * 4) - 1; + +run_tests(); + +__DATA__ + +=== TEST 1: should work in rewrite phase +--- http_config eval: $t::Util::HttpConfig +--- config + location = /t { + rewrite_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + pdk_rl.store_response_header(ngx.ctx, "X-1", 1) + pdk_rl.store_response_header(ngx.ctx, "X-2", 2) + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + + pdk_rl.apply_response_headers(ngx.ctx) + } + + content_by_lua_block { + ngx.say("ok") + } + + log_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + } + } +--- request +GET /t +--- response_headers +X-1: 1 +X-2: 2 +--- no_error_log +[error] + + + +=== TEST 2: should work in access phase +--- http_config eval: $t::Util::HttpConfig +--- config + location = /t { + access_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + pdk_rl.store_response_header(ngx.ctx, "X-1", 1) + pdk_rl.store_response_header(ngx.ctx, "X-2", 2) + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + + pdk_rl.apply_response_headers(ngx.ctx) + } + + content_by_lua_block { + ngx.say("ok") + } + + log_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + } + } +--- request +GET /t +--- response_headers +X-1: 1 +X-2: 2 +--- no_error_log +[error] + + +=== TEST 3: should work in header_filter phase +--- http_config eval: $t::Util::HttpConfig +--- config + location = /t { + header_filter_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + pdk_rl.store_response_header(ngx.ctx, "X-1", 1) + pdk_rl.store_response_header(ngx.ctx, "X-2", 2) + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + + pdk_rl.apply_response_headers(ngx.ctx) + } + + content_by_lua_block { + ngx.say("ok") + } + + log_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + + local value = pdk_rl.get_stored_response_header(ngx.ctx, "X-1") + assert(value == 1, "unexpected value: " .. value) + + value = pdk_rl.get_stored_response_header(ngx.ctx, "X-2") + assert(value == 2, "unexpected value: " .. value) + } + } +--- request +GET /t +--- response_headers +X-1: 1 +X-2: 2 +--- no_error_log +[error] + + + +=== TEST 4: should not accept invalid arguments +--- http_config eval: $t::Util::HttpConfig +--- config + location = /t { + rewrite_by_lua_block { + local pdk_rl = require("kong.pdk.private.rate_limiting") + local ok, err + + ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, nil, 1) + assert(not ok, "pcall should fail") + assert( + err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true), + "unexpected error message: " .. err + ) + for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do + ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, v, 1) + assert(not ok, "pcall should fail") + assert( + err:find("arg #2 `key` for function `store_response_header` must be a string", nil, true), + "unexpected error message: " .. err + ) + end + + ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", nil) + assert(not ok, "pcall should fail") + assert( + err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true), + "unexpected error message: " .. err + ) + for _k, v in ipairs({ true, {}, function() end, ngx.null }) do + ok, err = pcall(pdk_rl.store_response_header, ngx.ctx, "X-1", v) + assert(not ok, "pcall should fail") + assert( + err:find("arg #3 `value` for function `store_response_header` must be a string or a number", nil, true), + "unexpected error message: " .. err + ) + end + + ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, nil) + assert(not ok, "pcall should fail") + assert( + err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true), + "unexpected error message: " .. err + ) + for _k, v in ipairs({ 1, true, {}, function() end, ngx.null }) do + ok, err = pcall(pdk_rl.get_stored_response_header, ngx.ctx, v) + assert(not ok, "pcall should fail") + assert( + err:find("arg #2 `key` for function `get_stored_response_header` must be a string", nil, true), + "unexpected error message: " .. err + ) + end + } + + content_by_lua_block { + ngx.print("ok") + } + } +--- request +GET /t +--- response_body eval +"ok" +--- no_error_log +[error] \ No newline at end of file