From 13d642bb46e56d8e0791f2d1b7a6762da571dfff Mon Sep 17 00:00:00 2001 From: Diaa Sami Date: Wed, 1 Nov 2023 14:26:38 +0100 Subject: [PATCH] pkg/splunk_logger: handle potentially dangling goroutine --- pkg/splunk_logger/splunk_hook.go | 5 +- pkg/splunk_logger/splunk_logger.go | 13 +++-- pkg/splunk_logger/splunk_logger_test.go | 65 +++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 pkg/splunk_logger/splunk_logger_test.go diff --git a/pkg/splunk_logger/splunk_hook.go b/pkg/splunk_logger/splunk_hook.go index 32dd20b672..4c0f0702a5 100644 --- a/pkg/splunk_logger/splunk_hook.go +++ b/pkg/splunk_logger/splunk_hook.go @@ -1,6 +1,7 @@ package logger import ( + "context" "fmt" "os" @@ -11,7 +12,7 @@ type SplunkHook struct { sl *SplunkLogger } -func NewSplunkHook(host, port, token, source string) (*SplunkHook, error) { +func NewSplunkHook(context context.Context, host, port, token, source string) (*SplunkHook, error) { url := fmt.Sprintf("https://%s:%s/services/collector/event", host, port) hostname, err := os.Hostname() if err != nil { @@ -19,7 +20,7 @@ func NewSplunkHook(host, port, token, source string) (*SplunkHook, error) { } return &SplunkHook{ - sl: NewSplunkLogger(url, token, source, hostname), + sl: NewSplunkLogger(context, url, token, source, hostname), }, nil } diff --git a/pkg/splunk_logger/splunk_logger.go b/pkg/splunk_logger/splunk_logger.go index f124782eb4..1ec68a4e14 100644 --- a/pkg/splunk_logger/splunk_logger.go +++ b/pkg/splunk_logger/splunk_logger.go @@ -2,6 +2,7 @@ package logger import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -40,7 +41,7 @@ type SplunkEvent struct { Host string `json:"host"` } -func NewSplunkLogger(url, token, source, hostname string) *SplunkLogger { +func NewSplunkLogger(context context.Context, url, token, source, hostname string) *SplunkLogger { sl := &SplunkLogger{ client: retryablehttp.NewClient().StandardClient(), url: url, @@ -52,15 +53,21 @@ func NewSplunkLogger(url, token, source, hostname string) *SplunkLogger { ticker := time.NewTicker(time.Second * SendFrequency) sl.payloads = make(chan *SplunkPayload, PayloadsChannelSize) - go sl.flushPayloads(ticker.C) + go sl.flushPayloads(context, ticker.C) return sl } -func (sl *SplunkLogger) flushPayloads(ticker <-chan time.Time) { +func (sl *SplunkLogger) flushPayloads(context context.Context, ticker <-chan time.Time) { var payloads []*SplunkPayload for { select { + case <-context.Done(): + err := sl.SendPayloads(payloads) + if err != nil { + fmt.Fprintf(os.Stderr, "Splunk logger unable to send payloads: %v", err) + } + return case p := <-sl.payloads: if p != nil { payloads = append(payloads, p) diff --git a/pkg/splunk_logger/splunk_logger_test.go b/pkg/splunk_logger/splunk_logger_test.go new file mode 100644 index 0000000000..fab4dee071 --- /dev/null +++ b/pkg/splunk_logger/splunk_logger_test.go @@ -0,0 +1,65 @@ +package logger + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSplunkLogger(t *testing.T) { + ch := make(chan bool) + time.AfterFunc(time.Second*10, func() { + ch <- false + }) + count := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // make sure the logger retries requests + if count == 0 { + count += 1 + w.WriteHeader(http.StatusInternalServerError) + return + } + require.Equal(t, "Splunk", r.Header.Get("Authorization")) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + var sp SplunkPayload + err := json.NewDecoder(r.Body).Decode(&sp) + require.NoError(t, err) + require.Equal(t, "test-host", sp.Host) + require.Equal(t, "test-host", sp.Event.Host) + require.Equal(t, "image-builder", sp.Event.Ident) + require.Equal(t, "message", sp.Event.Message) + ch <- true + })) + sl := NewSplunkLogger(context.Background(), srv.URL, "", "image-builder", "test-host") + require.NoError(t, sl.LogWithTime(time.Now(), "message")) + require.True(t, <-ch) +} + +func TestSplunkLoggerContext(t *testing.T) { + ch := make(chan bool) + time.AfterFunc(time.Second*10, func() { + ch <- false + }) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Splunk", r.Header.Get("Authorization")) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + var sp SplunkPayload + err := json.NewDecoder(r.Body).Decode(&sp) + require.NoError(t, err) + require.Equal(t, "test-host", sp.Host) + require.Equal(t, "test-host", sp.Event.Host) + require.Equal(t, "image-builder", sp.Event.Ident) + require.Equal(t, "message", sp.Event.Message) + ch <- true + })) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) + defer cancel() + sl := NewSplunkLogger(ctx, srv.URL, "", "image-builder", "test-host") + require.NoError(t, sl.LogWithTime(time.Now(), "message")) + require.True(t, <-ch) +}