diff --git a/kong/db/dao/init.lua b/kong/db/dao/init.lua index 561dd2644973..1c26fd849d4c 100644 --- a/kong/db/dao/init.lua +++ b/kong/db/dao/init.lua @@ -15,6 +15,7 @@ local new_tab = require "table.new" local DAO_MAX_TTL = require("kong.constants").DATABASE.DAO_MAX_TTL local get_request_id = require("kong.tracing.request_id").get local is_valid_uuid = require("kong.tools.uuid").is_valid_uuid +local deep_copy = require("kong.tools.table").deep_copy local setmetatable = setmetatable local tostring = tostring @@ -312,6 +313,12 @@ local function validate_options_value(self, options) end end + if options.skip_ttl ~= nil then + if type(options.skip_ttl) ~= "boolean" then + errors.skip_ttl = "must be a boolean" + end + end + if next(errors) then return nil, errors end @@ -948,14 +955,9 @@ local function generate_foreign_key_methods(schema) return nil, err, err_t end - -- Must have fully hydrated entity (including workspace id) for hooks to notify on - local show_ws_id = { show_ws_id = true } - if options ~= nil then - for k, v in pairs(options) do - show_ws_id[k] = v - end - end - local entity, err, err_t = self["select_by_" .. name](self, unique_value, show_ws_id) + local select_options = deep_copy(options or {}) + select_options["show_ws_id"] = true + local entity, err, err_t = self["select_by_" .. name](self, unique_value, select_options) if err then return nil, err, err_t end @@ -964,7 +966,7 @@ local function generate_foreign_key_methods(schema) return true end - local cascade_entries = find_cascade_delete_entities(self, entity, show_ws_id) + local cascade_entries = find_cascade_delete_entities(self, entity, select_options) local ok, err_t = run_hook("dao:delete_by:pre", entity, @@ -1354,14 +1356,9 @@ function DAO:delete(pk_or_entity, options) return nil, tostring(err_t), err_t end - -- Must have fully hydrated entity (including workspace id) for hooks to notify on - local show_ws_id = { show_ws_id = true } - if options ~= nil then - for k, v in pairs(options) do - show_ws_id[k] = v - end - end - local entity, err, err_t = self:select(primary_key, show_ws_id) + local select_options = deep_copy(options or {}) + select_options["show_ws_id"] = true + local entity, err, err_t = self:select(primary_key, select_options) if err then return nil, err, err_t end @@ -1378,7 +1375,7 @@ function DAO:delete(pk_or_entity, options) end end - local cascade_entries = find_cascade_delete_entities(self, primary_key, show_ws_id) + local cascade_entries = find_cascade_delete_entities(self, primary_key, select_options) local ws_id = entity.ws_id local _ diff --git a/kong/db/strategies/postgres/init.lua b/kong/db/strategies/postgres/init.lua index 14034d13cb7c..addd22076f73 100644 --- a/kong/db/strategies/postgres/init.lua +++ b/kong/db/strategies/postgres/init.lua @@ -432,7 +432,7 @@ local function execute(strategy, statement_name, attributes, options) local is_update = options and options.update local has_ttl = strategy.schema.ttl - + local skip_ttl = options and options.skip_ttl if has_ws_id then assert(ws_id == nil or type(ws_id) == "string") argv[0] = escape_literal(connector, ws_id, "ws_id") @@ -441,7 +441,7 @@ local function execute(strategy, statement_name, attributes, options) for i = 1, argc do local name = argn[i] local value - if has_ttl and name == "ttl" then + if has_ttl and name == "ttl" and not skip_ttl then value = (options and options.ttl) and get_ttl_value(strategy, attributes, options) @@ -707,7 +707,12 @@ end function _mt:select(primary_key, options) - local res, err = execute(self, "select", self.collapse(primary_key), options) + local statement_name = "select" + if self.schema.ttl and options and options.skip_ttl then + statement_name = "select_skip_ttl" + end + + local res, err = execute(self, statement_name, self.collapse(primary_key), options) if res then local row = res[1] if row then @@ -723,6 +728,11 @@ end function _mt:select_by_field(field_name, unique_value, options) local statement_name = "select_by_" .. field_name + + if self.schema.ttl and options and options.skip_ttl then + statement_name = statement_name .. "_skip_ttl" + end + local filter = { [field_name] = unique_value, } @@ -826,7 +836,11 @@ end function _mt:delete(primary_key, options) - local res, err = execute(self, "delete", self.collapse(primary_key), options) + local statement_name = "delete" + if self.schema.ttl and options and options.skip_ttl then + statement_name = "delete_skip_ttl" + end + local res, err = execute(self, statement_name, self.collapse(primary_key), options) if res then if res.affected_rows == 0 then return nil, nil @@ -841,6 +855,9 @@ end function _mt:delete_by_field(field_name, unique_value, options) local statement_name = "delete_by_" .. field_name + if self.schema.ttl and options and options.skip_ttl then + statement_name = statement_name .. "_skip_ttl" + end local filter = { [field_name] = unique_value, } @@ -1319,6 +1336,19 @@ function _M.new(connector, schema, errors) } }) + add_statement("delete_skip_ttl", { + operation = "write", + argn = primary_key_names, + argv = primary_key_args, + code = { + "DELETE\n", + " FROM ", table_name_escaped, "\n", + where_clause( + " WHERE ", "(" .. pk_escaped .. ") = (" .. primary_key_placeholders .. ")", + ws_id_select_where), ";" + } + }) + add_statement("select", { operation = "read", expr = select_expressions, @@ -1335,6 +1365,21 @@ function _M.new(connector, schema, errors) } }) + add_statement("select_skip_ttl", { + operation = "read", + expr = select_expressions, + argn = primary_key_names, + argv = primary_key_args, + code = { + "SELECT ", select_expressions, "\n", + " FROM ", table_name_escaped, "\n", + where_clause( + " WHERE ", "(" .. pk_escaped .. ") = (" .. primary_key_placeholders .. ")", + ws_id_select_where), + " LIMIT 1;" + } + }) + add_statement_for_export("page_first", { operation = "read", argn = { LIMIT }, @@ -1685,6 +1730,20 @@ function _M.new(connector, schema, errors) }, }) + add_statement("select_by_" .. field_name .. "_skip_ttl", { + operation = "read", + argn = single_names, + argv = single_args, + code = { + "SELECT ", select_expressions, "\n", + " FROM ", table_name_escaped, "\n", + where_clause( + " WHERE ", unique_escaped .. " = $1", + ws_id_select_where), + " LIMIT 1;" + }, + }) + local update_by_args_names = {} for _, update_name in ipairs(update_names) do insert(update_by_args_names, update_name) @@ -1740,6 +1799,19 @@ function _M.new(connector, schema, errors) ws_id_select_where), ";" } }) + + add_statement("delete_by_" .. field_name .. "_skip_ttl", { + operation = "write", + argn = single_names, + argv = single_args, + code = { + "DELETE\n", + " FROM ", table_name_escaped, "\n", + where_clause( + " WHERE ", unique_escaped .. " = $1", + ws_id_select_where), ";" + } + }) end end diff --git a/spec/02-integration/03-db/14-dao_spec.lua b/spec/02-integration/03-db/14-dao_spec.lua index 6708f58aaf8e..8c388ba196ce 100644 --- a/spec/02-integration/03-db/14-dao_spec.lua +++ b/spec/02-integration/03-db/14-dao_spec.lua @@ -23,6 +23,7 @@ for _, strategy in helpers.all_strategies() do "services", "consumers", "acls", + "keyauth_credentials", }) _G.kong.db = db @@ -105,6 +106,7 @@ for _, strategy in helpers.all_strategies() do db.consumers:truncate() db.plugins:truncate() db.services:truncate() + db.keyauth_credentials:truncate() end) it("select_by_cache_key()", function() @@ -192,6 +194,36 @@ for _, strategy in helpers.all_strategies() do assert.same(new_plugin_config.config.redis.host, read_plugin.config.redis.host) assert.same(new_plugin_config.config.redis.host, read_plugin.config.redis_host) -- legacy field is included end) + + it("keyauth_credentials can be deleted or selected before run ttl cleanup in background timer", function() + local key = uuid() + local original_keyauth_credentials = bp.keyauth_credentials:insert({ + consumer = { id = consumer.id }, + key = key, + }, { ttl = 5 }) + + -- wait for 5 seconds. + ngx.sleep(5) + + -- select or delete keyauth_credentials after ttl expired. + local expired_keyauth_credentials + helpers.wait_until(function() + expired_keyauth_credentials = kong.db.keyauth_credentials:select_by_key(key) + return not expired_keyauth_credentials + end, 1) + assert.is_nil(expired_keyauth_credentials) + kong.db.keyauth_credentials:delete_by_key(key) + + -- select or delete keyauth_credentials with skip_ttl=true after ttl expired. + expired_keyauth_credentials = kong.db.keyauth_credentials:select_by_key(key, { skip_ttl = true }) + assert.not_nil(expired_keyauth_credentials) + assert.same(expired_keyauth_credentials.id, original_keyauth_credentials.id) + kong.db.keyauth_credentials:delete_by_key(key, { skip_ttl = true }) + + -- check again + expired_keyauth_credentials = kong.db.keyauth_credentials:select_by_key(key, { skip_ttl = true }) + assert.is_nil(expired_keyauth_credentials) + end) end) end