Skip to content

Commit

Permalink
Allow power event watcher interrupt to be called multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Oct 11, 2023
1 parent af5086f commit 9d658e9
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
10 changes: 9 additions & 1 deletion pkg/windows/powereventwatcher/power_event_watcher_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
)

type noOpPowerEventWatcher struct {
interrupt chan struct{}
interrupt chan struct{}
interrupted bool
}

func New(_ types.Knapsack, _ log.Logger) (*noOpPowerEventWatcher, error) {
Expand All @@ -24,5 +25,12 @@ func (n *noOpPowerEventWatcher) Execute() error {
}

func (n *noOpPowerEventWatcher) Interrupt(_ error) {
// Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls.
if n.interrupted {
return
}

n.interrupted = true

n.interrupt <- struct{}{}
}
53 changes: 53 additions & 0 deletions pkg/windows/powereventwatcher/power_event_watcher_other_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//go:build !windows
// +build !windows

package powereventwatcher

import (
"errors"
"testing"
"time"

"github.com/go-kit/kit/log"
typesmocks "github.com/kolide/launcher/pkg/agent/types/mocks"
"github.com/stretchr/testify/require"
)

func TestInterrupt_Multiple(t *testing.T) {
t.Parallel()

p, err := New(typesmocks.NewKnapsack(t), log.NewNopLogger())
require.NoError(t, err)

// Start and then interrupt
go p.Execute()
p.Interrupt(errors.New("test error"))

// Confirm we can call Interrupt multiple times without blocking
interruptComplete := make(chan struct{})
expectedInterrupts := 3
for i := 0; i < expectedInterrupts; i += 1 {
go func() {
p.Interrupt(nil)
interruptComplete <- struct{}{}
}()
}

receivedInterrupts := 0
for {
if receivedInterrupts >= expectedInterrupts {
break
}

select {
case <-interruptComplete:
receivedInterrupts += 1
continue
case <-time.After(5 * time.Second):
t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts)
t.FailNow()
}
}

require.Equal(t, expectedInterrupts, receivedInterrupts)
}
8 changes: 8 additions & 0 deletions pkg/windows/powereventwatcher/power_event_watcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type (
unsubscribeProcedure *syscall.LazyProc
renderEventLogProcedure *syscall.LazyProc
interrupt chan struct{}
interrupted bool
}
)

Expand Down Expand Up @@ -102,6 +103,13 @@ func (p *powerEventWatcher) Execute() error {
}

func (p *powerEventWatcher) Interrupt(_ error) {
// Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls.
if p.interrupted {
return
}

p.interrupted = true

// EvtClose: https://learn.microsoft.com/en-us/windows/win32/api/winevt/nf-winevt-evtclose
ret, _, err := p.unsubscribeProcedure.Call(p.subscriptionHandle)
level.Debug(p.logger).Log("msg", "unsubscribed from power events", "ret", fmt.Sprintf("%+v", ret), "last_err", err)
Expand Down

0 comments on commit 9d658e9

Please sign in to comment.