diff --git a/pkg/component/controller/leaderelector/leasepool.go b/pkg/component/controller/leaderelector/leasepool.go index 65a2de09874e..d510b49e31db 100644 --- a/pkg/component/controller/leaderelector/leasepool.go +++ b/pkg/component/controller/leaderelector/leasepool.go @@ -18,6 +18,7 @@ package leaderelector import ( "context" + "errors" "fmt" "sync" @@ -75,44 +76,25 @@ func (l *LeasePool) Start(context.Context) error { return err } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancelCause(context.Background()) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done(); client.Run(ctx, l.status.Set) }() - go func() { defer wg.Done(); l.invokeCallbacks(ctx.Done()) }() + go func() { defer wg.Done(); l.invokeCallbacks(ctx) }() - l.stop = func() { cancel(); wg.Wait() } + l.stop = func() { cancel(errors.New("lease pool is stopping")); wg.Wait() } return nil } -func (l *LeasePool) invokeCallbacks(done <-chan struct{}) { - var lastStatus leaderelection.Status - - for { - status, statusChanged := l.status.Peek() - - if status != lastStatus { - lastStatus = status - if status == leaderelection.StatusLeading { - l.log.Info("acquired leader lease") - runCallbacks(l.acquiredLeaseCallbacks) - } else { - l.log.Info("lost leader lease") - runCallbacks(l.lostLeaseCallbacks) - } - } - - select { - case <-statusChanged: - case <-done: - l.log.Info("Lease pool is stopping") - if status == leaderelection.StatusLeading { - runCallbacks(l.lostLeaseCallbacks) - } - return - } - } +func (l *LeasePool) invokeCallbacks(ctx context.Context) { + RunLeaderTasks(ctx, l.GetLeaderStatus, func(ctx context.Context) { + l.log.Info("acquired leader lease") + runCallbacks(l.acquiredLeaseCallbacks) + <-ctx.Done() + l.log.Infof("lost leader lease (%v)", context.Cause(ctx)) + runCallbacks(l.lostLeaseCallbacks) + }) } func runCallbacks(callbacks []func()) { @@ -148,3 +130,36 @@ func (l *LeasePool) IsLeader() bool { status, _ := l.GetLeaderStatus() return status == leaderelection.StatusLeading } + +// Indicates that the previously gained lead has been lost. +var ErrLostLead = errors.New("lost the lead") + +// Runs the provided tasks function when the lead is taken. It continuously +// monitors the leader election status using the provided peek function. When +// the lead is taken, the tasks function is called with a context that is +// cancelled either when the lead has been lost or ctx is done. After the tasks +// function returns, the process is repeated until ctx is done. +func RunLeaderTasks(ctx context.Context, peek func() (leaderelection.Status, <-chan struct{}), tasks func(context.Context)) { + for { + status, statusChanged := peek() + + if status == leaderelection.StatusLeading { + ctx, cancel := context.WithCancelCause(ctx) + go func() { + select { + case <-statusChanged: + cancel(ErrLostLead) + case <-ctx.Done(): + } + }() + + tasks(ctx) + } + + select { + case <-statusChanged: + case <-ctx.Done(): + return + } + } +}