Skip to content

Commit

Permalink
feat(shutdown): add panic handling
Browse files Browse the repository at this point in the history
use simplified handling while waiting for
golang/go#53757

Signed-off-by: Artsiom Koltun <[email protected]>
  • Loading branch information
artek-koltun committed Oct 30, 2023
1 parent 7609b47 commit dbad6e2
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 14 deletions.
30 changes: 28 additions & 2 deletions pkg/utils/shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package utils
import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -141,7 +142,7 @@ func (s *ShutdownHandler) RunAndWait() error {
for i := range s.serves {
fn := s.serves[i]
s.eg.Go(func() error {
return fn()
return wrapServeFuncPanic(fn)()
})
}

Expand All @@ -163,7 +164,8 @@ func (s *ShutdownHandler) RunAndWait() error {
for i := len(s.shutdowns) - 1; i >= 0; i-- {
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.timeoutPerShutdown)
defer cancel()
err = errors.Join(err, s.shutdowns[i](timeoutCtx))
shutdownFn := wrapShutdownFuncPanic(s.shutdowns[i])
err = errors.Join(err, shutdownFn(timeoutCtx))
}

return err
Expand All @@ -172,6 +174,30 @@ func (s *ShutdownHandler) RunAndWait() error {
return s.eg.Wait()
}

func wrapServeFuncPanic(fn ServeFunc) ServeFunc {
return func() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("was panic for serve function, recovered value: %v", r)
}
}()
err = fn()
return err
}
}

func wrapShutdownFuncPanic(fn ShutdownFunc) ShutdownFunc {
return func(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("was panic for shutdown function, recovered value: %v", r)
}
}()
err = fn(ctx)
return err
}
}

func runWithCtx(ctx context.Context, fn func() error) error {
var err error

Expand Down
69 changes: 57 additions & 12 deletions pkg/utils/shutdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package utils
import (
"context"
"errors"
"log"
"os"
"sync"
"testing"
Expand All @@ -26,19 +27,24 @@ func newServeShutdownPair(
serveIDs, shutdownIDs *[]int,
mu *sync.Mutex,
serveErr, shutdownErr error,
servePanic bool, shutdownPanic bool,
) *serveShutdownPair {
shutdownTrigger := make(chan struct{}, 1)
s := &serveShutdownPair{
shutdownTrigger: shutdownTrigger,
serve: func() error {
if serveErr == nil {
<-shutdownTrigger
}

mu.Lock()
*serveIDs = append(*serveIDs, fnID)
mu.Unlock()

if servePanic {
log.Panic("Panic!")
}

if serveErr == nil {
<-shutdownTrigger
}

return serveErr
},
shutdown: func(ctx context.Context) error {
Expand All @@ -48,38 +54,74 @@ func newServeShutdownPair(

shutdownTrigger <- struct{}{}

if shutdownPanic {
log.Panic("Panic!")
}

return shutdownErr
},
}

return s
}

func errString(err error) string {
if err == nil {
return ""
}

return err.Error()
}

func TestRunAndWait(t *testing.T) {
stubErr := errors.New("stub error")
tests := map[string]struct {
giveServeErr error
giveServePanic bool
giveShutdownErr error
giveShutdownPanic bool
stoppedByInterrupt bool
wantErr error
wantErr string
}{
"all services successfully completed": {
giveServeErr: nil,
giveServePanic: false,
giveShutdownErr: nil,
giveShutdownPanic: false,
stoppedByInterrupt: true,
wantErr: nil,
wantErr: "",
},
"serve failed": {
giveServeErr: stubErr,
giveServePanic: false,
giveShutdownErr: nil,
stoppedByInterrupt: false,
wantErr: stubErr,
giveShutdownPanic: false,
wantErr: stubErr.Error(),
},
"shutdown failed": {
giveServeErr: nil,
giveServePanic: false,
giveShutdownErr: stubErr,
giveShutdownPanic: false,
stoppedByInterrupt: true,
wantErr: stubErr.Error(),
},
"serve panic": {
giveServeErr: nil,
giveServePanic: true,
giveShutdownErr: nil,
giveShutdownPanic: false,
stoppedByInterrupt: false,
wantErr: "was panic for serve function, recovered value: Panic!",
},
"shutdown panic": {
giveServeErr: nil,
giveServePanic: false,
giveShutdownErr: nil,
giveShutdownPanic: true,
stoppedByInterrupt: true,
wantErr: stubErr,
wantErr: "was panic for shutdown function, recovered value: Panic!",
},
}
for testName, tt := range tests {
Expand All @@ -89,9 +131,12 @@ func TestRunAndWait(t *testing.T) {
serveFnIDs := &[]int{}
shutdownFnIDs := &[]int{}
mu := sync.Mutex{}
s0 := newServeShutdownPair(0, serveFnIDs, shutdownFnIDs, &mu, nil, nil)
s1 := newServeShutdownPair(1, serveFnIDs, shutdownFnIDs, &mu, tt.giveServeErr, tt.giveShutdownErr)
s2 := newServeShutdownPair(2, serveFnIDs, shutdownFnIDs, &mu, nil, nil)
s0 := newServeShutdownPair(0, serveFnIDs, shutdownFnIDs, &mu, nil, nil, false, false)
s1 := newServeShutdownPair(1, serveFnIDs, shutdownFnIDs, &mu,
tt.giveServeErr, tt.giveShutdownErr,
tt.giveServePanic, tt.giveShutdownPanic,
)
s2 := newServeShutdownPair(2, serveFnIDs, shutdownFnIDs, &mu, nil, nil, false, false)

sh.AddServe(s0.serve, s0.shutdown)
sh.AddServe(s1.serve, s1.shutdown)
Expand All @@ -103,7 +148,7 @@ func TestRunAndWait(t *testing.T) {

err := sh.RunAndWait()

if !errors.Is(err, tt.wantErr) {
if errString(err) != tt.wantErr {
t.Errorf("Expected error: %v, received: %v", tt.wantErr, err)
}

Expand Down

0 comments on commit dbad6e2

Please sign in to comment.