Skip to content

Commit

Permalink
refactor(concurrency): consistent node-level locks
Browse files Browse the repository at this point in the history
Several places in the gateway need a node-level lock, some of them used
slightly different implementations. This refactor brings consistency in
the ways we do node-level locking by using the same implementation
(concurrency.with_worker_mutex) everywhere.
  • Loading branch information
samugi committed Aug 13, 2024
1 parent cfd997f commit 5e74999
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 76 deletions.
92 changes: 29 additions & 63 deletions kong/cluster_events/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ local timer_at = ngx.timer.at
local ngx_update_time = ngx.update_time

local knode = kong and kong.node or require "kong.pdk.node".new()
local concurrency = require "kong.concurrency"

local POLL_INTERVAL_LOCK_KEY = "cluster_events:poll_interval"
local POLL_RUNNING_LOCK_KEY = "cluster_events:poll_running"
Expand Down Expand Up @@ -326,80 +327,45 @@ if ngx_debug then
end


local function get_lock(self)
-- check if a poll is not currently running, to ensure we don't start
-- another poll while a worker is still stuck in its own polling (in
-- case it is being slow)
-- we still add an exptime to this lock in case something goes horribly
-- wrong, to ensure other workers can poll new events
-- a poll cannot take more than max(poll_interval * 5, 10) -- 10s min
local ok, err = self.shm:safe_add(POLL_RUNNING_LOCK_KEY, true,
max(self.poll_interval * 5, 10))
if not ok then
if err ~= "exists" then
log(ERR, "failed to acquire poll_running lock: ", err)
end
-- else
-- log(DEBUG, "failed to acquire poll_running lock: ",
-- "a worker still holds the lock")

return false
end

if self.poll_interval > 0.001 then
-- check if interval of `poll_interval` has elapsed already, to ensure
-- we do not run the poll when a previous poll was quickly executed, but
-- another worker got the timer trigger a bit too late.
ok, err = self.shm:safe_add(POLL_INTERVAL_LOCK_KEY, true,
self.poll_interval - 0.001)
if not ok then
if err ~= "exists" then
log(ERR, "failed to acquire poll_interval lock: ", err)
end
-- else
-- log(DEBUG, "failed to acquire poll_interval lock: ",
-- "not enough time elapsed since last poll")

self.shm:delete(POLL_RUNNING_LOCK_KEY)

return false
end
end

return true
end


poll_handler = function(premature, self)
if premature or not self.polling then
-- set self.polling to false to stop a polling loop
return
end

if not get_lock(self) then
local ok, err = timer_at(self.poll_interval, poll_handler, self)
if not ok then
log(CRIT, "failed to start recurring polling timer: ", err)
-- check if a poll is not currently running, to ensure we don't start
-- another poll while a worker is still stuck in its own polling (in
-- case it is being slow)
-- we still add an exptime to this lock in case something goes horribly
-- wrong, to ensure other workers can poll new events
-- a poll cannot take more than max(poll_interval * 5, 10) -- 10s min
local ok, err = concurrency.with_worker_mutex({
name = POLL_RUNNING_LOCK_KEY,
timeout = 0,
exptime = max(self.poll_interval * 5, 10),
}, function()
if self.poll_interval > 0.001 then
-- check if interval of `poll_interval` has elapsed already, to ensure
-- we do not run the poll when a previous poll was quickly executed, but
-- another worker got the timer trigger a bit too late.
return concurrency.with_worker_mutex({
name = POLL_INTERVAL_LOCK_KEY,
timeout = 0,
exptime = self.poll_interval - 0.001,
}, function()
return poll(self)
end)
end

return
end
return poll(self)
end)

-- single worker

local pok, perr, err = pcall(poll, self)
if not pok then
log(ERR, "poll() threw an error: ", perr)

elseif not perr then
log(ERR, "failed to poll: ", err)
if not ok and err ~= "timeout" then
log(ERR, err)
end

-- unlock

self.shm:delete(POLL_RUNNING_LOCK_KEY)

local ok, err = timer_at(self.poll_interval, poll_handler, self)
-- schedule next polling timer
ok, err = timer_at(self.poll_interval, poll_handler, self)
if not ok then
log(CRIT, "failed to start recurring polling timer: ", err)
end
Expand Down
3 changes: 2 additions & 1 deletion kong/concurrency.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ function concurrency.with_worker_mutex(opts, fn)
local elapsed, err = rlock:lock(opts_name)
if not elapsed then
if err == "timeout" then
return nil, err
local ttl = rlock.dict and rlock.dict:ttl(opts_name)
return nil, err, ttl
end
return nil, "failed to acquire worker lock: " .. err
end
Expand Down
23 changes: 12 additions & 11 deletions kong/db/declarative/import.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ local constants = require("kong.constants")
local workspaces = require("kong.workspaces")
local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy
local declarative_config = require("kong.db.schema.others.declarative_config")
local concurrency = require("kong.concurrency")


local yield = require("kong.tools.yield").yield
Expand Down Expand Up @@ -571,25 +572,25 @@ do
local DECLARATIVE_RETRY_TTL_MAX = 10
local DECLARATIVE_LOCK_KEY = "declarative:lock"

-- make sure no matter which path it exits, we released the lock.
load_into_cache_with_events = function(entities, meta, hash, hashes)
local kong_shm = ngx.shared.kong
local ok, err, ttl = concurrency.with_worker_mutex({
name = DECLARATIVE_LOCK_KEY,
timeout = 0,
exptime = DECLARATIVE_LOCK_TTL,
}, function()
return load_into_cache_with_events_no_lock(entities, meta, hash, hashes)
end)

local ok, err = kong_shm:add(DECLARATIVE_LOCK_KEY, 0, DECLARATIVE_LOCK_TTL)
if not ok then
if err == "exists" then
local ttl = min(kong_shm:ttl(DECLARATIVE_LOCK_KEY), DECLARATIVE_RETRY_TTL_MAX)
return nil, "busy", ttl
if err == "timeout" then
ttl = ttl or DECLARATIVE_RETRY_TTL_MAX
local retry_after = min(ttl, DECLARATIVE_RETRY_TTL_MAX)
return nil, "busy", retry_after
end

kong_shm:delete(DECLARATIVE_LOCK_KEY)
return nil, err
end

ok, err = load_into_cache_with_events_no_lock(entities, meta, hash, hashes)

kong_shm:delete(DECLARATIVE_LOCK_KEY)

return ok, err
end
end
Expand Down
19 changes: 18 additions & 1 deletion spec/02-integration/06-invalidations/01-cluster_events_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@ _G.ngx.config.debug = true

local helpers = require "spec.helpers"
local kong_cluster_events = require "kong.cluster_events"
local match = require "luassert.match"


for _, strategy in helpers.each_strategy() do
describe("cluster_events with db [#" .. strategy .. "]", function()
local db
local db, log_spy, orig_ngx_log

lazy_setup(function()
local _
_, db = helpers.get_db_utils(strategy, {})

orig_ngx_log = ngx.log
local logged = { level = function() end }
log_spy = spy.on(logged, "level")
_G.ngx.log = function(l) logged.level(l) end -- luacheck: ignore
end)

lazy_teardown(function()
local cluster_events = assert(kong_cluster_events.new { db = db })
cluster_events.strategy:truncate_events()

_G.ngx.log = orig_ngx_log -- luacheck: ignore
end)

before_each(function()
Expand Down Expand Up @@ -121,6 +129,7 @@ for _, strategy in helpers.each_strategy() do

assert(cluster_events_1:poll())
assert.spy(spy_func).was_called(3)
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("broadcasts data to subscribers", function()
Expand All @@ -144,6 +153,7 @@ for _, strategy in helpers.each_strategy() do
assert(cluster_events_1:poll())
assert.spy(spy_func).was_called(1)
assert.spy(spy_func).was_called_with("hello world")
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("does not broadcast events on the same node", function()
Expand All @@ -165,6 +175,7 @@ for _, strategy in helpers.each_strategy() do

assert(cluster_events_1:poll())
assert.spy(spy_func).was_not_called()
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("starts interval polling when subscribing", function()
Expand Down Expand Up @@ -199,6 +210,7 @@ for _, strategy in helpers.each_strategy() do
helpers.wait_until(function()
return called == 2
end, 10)
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("applies a poll_offset to lookback potentially missed events", function()
Expand Down Expand Up @@ -240,6 +252,7 @@ for _, strategy in helpers.each_strategy() do

assert(cluster_events_1:poll())
assert.spy(spy_func).was_called(2) -- not called again this time
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("handles more than <PAGE_SIZE> events at once", function()
Expand All @@ -263,6 +276,7 @@ for _, strategy in helpers.each_strategy() do

assert(cluster_events_1:poll())
assert.spy(spy_func).was_called(201)
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("runs callbacks in protected mode", function()
Expand All @@ -285,6 +299,7 @@ for _, strategy in helpers.each_strategy() do
assert.has_no_error(function()
cluster_events_1:poll()
end)
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("broadcasts an event with a delay", function()
Expand Down Expand Up @@ -319,6 +334,7 @@ for _, strategy in helpers.each_strategy() do
assert(cluster_events_1:poll())
return pcall(assert.spy(spy_func).was_called, 1) -- called
end, 1) -- note that we have already waited for `delay` seconds
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)

it("broadcasts an event with a polling delay for subscribers", function()
Expand Down Expand Up @@ -356,6 +372,7 @@ for _, strategy in helpers.each_strategy() do
assert(cluster_events_1:poll())
return pcall(assert.spy(spy_func).was_called, 1) -- called
end, 1) -- note that we have already waited for `delay` seconds
assert.spy(log_spy).was_not_called_with(match.is_not.gt(ngx.ERR))
end)
end)
end)
Expand Down
16 changes: 16 additions & 0 deletions spec/helpers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2832,6 +2832,22 @@ luassert:register("assertion", "gt", is_gt,
"assertion.gt.negative",
"assertion.gt.positive")



---
-- Matcher to ensure a value is greater than a base value.
-- @function is_gt_matcher
-- @param base the base value to compare against
-- @param value the value that must be greater than the base value
local function is_gt_matcher(state, arguments)
local expected = arguments[1]
return function(value)
return value > expected
end
end
luassert:register("matcher", "gt", is_gt_matcher)


--- Generic modifier "certificate".
-- Will set a "certificate" value in the assertion state, so following
-- assertions will operate on the value set.
Expand Down

0 comments on commit 5e74999

Please sign in to comment.