Skip to content

Commit

Permalink
feat: graceful server shutdown
Browse files Browse the repository at this point in the history
Signed-off-by: Artsiom Koltun <[email protected]>
  • Loading branch information
artek-koltun committed Aug 28, 2023
1 parent 325cfd9 commit c717377
Show file tree
Hide file tree
Showing 3 changed files with 348 additions and 3 deletions.
13 changes: 10 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ import (
"fmt"
"log"
"net"
"os"
"strings"
"time"

"github.com/opiproject/gospdk/spdk"

"github.com/opiproject/opi-spdk-bridge/pkg/backend"
"github.com/opiproject/opi-spdk-bridge/pkg/frontend"
"github.com/opiproject/opi-spdk-bridge/pkg/kvm"
"github.com/opiproject/opi-spdk-bridge/pkg/middleend"
"github.com/opiproject/opi-spdk-bridge/pkg/server"

pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go"
"google.golang.org/grpc"
Expand Down Expand Up @@ -89,8 +92,12 @@ func main() {

reflection.Register(s)

log.Printf("Server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
wrapper := server.NewGRPCServerWrapper(2*time.Second, s, lis)

wrapper.RunAsync()
if err := wrapper.Wait(); err != nil {
log.Printf("Server error: %v", err)
os.Exit(-1)
}
log.Print("Server successfully stopped")
}
111 changes: 111 additions & 0 deletions pkg/server/grpcserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023 Intel Corporation

// Package server implements the server
package server

import (
"context"
"errors"
"log"
"net"
"os"
"os/signal"
"syscall"
"time"

"google.golang.org/grpc"
)

// GRPCServerWrapper wraps gRPC server to provide graceful shutdown capabilities
type GRPCServerWrapper struct {
waitSignal chan os.Signal
signalsToWait []os.Signal

timeout time.Duration

server *grpc.Server
listener net.Listener
waitServeComplete chan error
serve func(*grpc.Server, net.Listener) error
}

func defaultServe(s *grpc.Server, l net.Listener) error { return s.Serve(l) }

// NewGRPCServerWrapper creates a new instance of GRPCServerWrapper
func NewGRPCServerWrapper(
timeout time.Duration, server *grpc.Server, listener net.Listener,
) *GRPCServerWrapper {
if timeout == 0 {
log.Panicf("timeout cannot be zero")
}

if server == nil {
log.Panicf("grpc server cannot be nil")
}

if listener == nil {
log.Panic("listener cannot be nil")
}

return &GRPCServerWrapper{
waitSignal: make(chan os.Signal, 1),
signalsToWait: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
timeout: timeout,
server: server,
listener: listener,
waitServeComplete: make(chan error, 1),
serve: defaultServe,
}
}

// RunAsync runs gRPC server
func (s *GRPCServerWrapper) RunAsync() {
go func() {
log.Printf("Server listening at %v", s.listener.Addr())
s.waitServeComplete <- s.serve(s.server, s.listener)
}()
}

// Wait waits for a signal and handles graceful completion
func (s *GRPCServerWrapper) Wait() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
signal.Notify(s.waitSignal, s.signalsToWait...)
select {
case sig := <-s.waitSignal:
log.Printf("Got signal: %v", sig)
log.Printf("Start graceful shutdown with timeout: %v", s.timeout)
time.AfterFunc(s.timeout, func() { cancel() })
s.stopServer(ctx)
case <-ctx.Done():
log.Println("Stop listening for a signal")
}
}()

select {
case err := <-s.waitServeComplete:
return err
case <-ctx.Done():
return errors.New("server stop timeout elapsed")
}
}

func (s *GRPCServerWrapper) stopServer(ctx context.Context) {
log.Println("Stop server")

stopped := make(chan struct{}, 1)
go func() {
s.server.GracefulStop()
close(stopped)
}()

select {
case <-ctx.Done():
log.Println("Server stop context done")
s.server.Stop()
case <-stopped:
log.Println("GracefulStop completed")
}
}
227 changes: 227 additions & 0 deletions pkg/server/grpcserver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023 Intel Corporation

// Package server implements the server
package server

import (
"context"
"errors"
"log"
"net"
"os"
"sync"
"syscall"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"

pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go"
)

const timeout = 50 * time.Millisecond

type TestServer struct {
pb.MiddleendEncryptionServiceServer
wait time.Duration
startedHandlingCall sync.WaitGroup
}

func (b *TestServer) CreateEncryptedVolume(_ context.Context, _ *pb.CreateEncryptedVolumeRequest) (*pb.EncryptedVolume, error) {
b.startedHandlingCall.Done()
time.Sleep(b.wait)
return &pb.EncryptedVolume{}, nil
}

type testEnv struct {
testServer *TestServer
client pb.MiddleendEncryptionServiceClient
conn *grpc.ClientConn
ln net.Listener
grpcServer *grpc.Server
}

func (e *testEnv) Close() {
CloseGrpcConnection(e.conn)
CloseListener(e.ln)
}

func createTestEnvironment(callTime time.Duration) *testEnv {
env := &testEnv{}
env.testServer = &TestServer{
pb.UnimplementedMiddleendEncryptionServiceServer{},
callTime,
sync.WaitGroup{},
}
env.grpcServer = grpc.NewServer()
listener := bufconn.Listen(1024 * 1024)
env.ln = listener
pb.RegisterMiddleendEncryptionServiceServer(env.grpcServer, env.testServer)

ctx := context.Background()
conn, err := grpc.DialContext(ctx,
"",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}))
if err != nil {
log.Fatal(err)
}
env.conn = conn
env.client = pb.NewMiddleendEncryptionServiceClient(env.conn)

return env
}

func TestGRPCWrapperWait(t *testing.T) {
tests := map[string]struct {
callTime time.Duration
wantErr bool
serve func(*grpc.Server, net.Listener) error
}{
"server stop timeout": {
callTime: timeout * 2,
wantErr: true,
},
"successful server stop": {
callTime: timeout / 10,
wantErr: false,
},
}
for testName, tt := range tests {
t.Run(testName, func(t *testing.T) {
testEnv := createTestEnvironment(tt.callTime)
defer testEnv.Close()
testEnv.testServer.startedHandlingCall.Add(1)

serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
// use rare signal in order not to catch a real interrupt
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
serverWrapper.RunAsync()

var (
clientResponse *pb.EncryptedVolume
clientErr error
)
clientDone := sync.WaitGroup{}
clientDone.Add(1)
go func() {
clientResponse, clientErr = testEnv.client.CreateEncryptedVolume(
context.Background(), &pb.CreateEncryptedVolumeRequest{})
clientDone.Done()
}()
testEnv.testServer.startedHandlingCall.Wait()

serverWrapper.waitSignal <- os.Interrupt
waitErr := serverWrapper.Wait()

if (waitErr != nil) != tt.wantErr {
t.Errorf("Expected elapsed: %v. received: %v", tt.wantErr, waitErr)
}
clientDone.Wait()
if (clientErr != nil) != tt.wantErr {
t.Errorf("Expected error %v, received: %v", tt.wantErr, clientErr)
}
if (clientResponse == nil) != tt.wantErr {
t.Errorf("Expected not nil response %v, received: %v", tt.wantErr, clientResponse)
}
})
}

t.Run("failed serve", func(t *testing.T) {
testEnv := createTestEnvironment(timeout)
defer testEnv.Close()
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
// use rare signal in order not to catch a real interrupt
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
stubErr := errors.New("some serve error")
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { return stubErr }
serverWrapper.RunAsync()

waitErr := serverWrapper.Wait()

if waitErr != stubErr {
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr)
}
})

t.Run("failed serve after signal received", func(t *testing.T) {
testEnv := createTestEnvironment(timeout)
defer testEnv.Close()
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln)
// use rare signal in order not to catch a real interrupt
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL}
stubErr := errors.New("some serve error")
wg := sync.WaitGroup{}
wg.Add(1)
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error {
wg.Wait()
return stubErr
}
serverWrapper.RunAsync()
go func() {
serverWrapper.waitSignal <- os.Interrupt
time.Sleep(timeout / 10)
wg.Done()
}()

waitErr := serverWrapper.Wait()

if waitErr != stubErr {
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr)
}
})
}

func TestNewGRPCWrapper(t *testing.T) {
tests := map[string]struct {
timeout time.Duration
grpcServer *grpc.Server
listener net.Listener
wantPanic bool
}{
"zero timeout": {
timeout: 0,
grpcServer: grpc.NewServer(),
listener: bufconn.Listen(32),
wantPanic: true,
},
"nil grpc server": {
timeout: timeout,
grpcServer: nil,
listener: bufconn.Listen(32),
wantPanic: true,
},
"nil listener": {
timeout: timeout,
grpcServer: grpc.NewServer(),
listener: nil,
wantPanic: true,
},
"successful wrapper creation": {
timeout: timeout,
grpcServer: grpc.NewServer(),
listener: bufconn.Listen(32),
wantPanic: false,
},
}
for testName, tt := range tests {
t.Run(testName, func(t *testing.T) {
defer func() {
r := recover()
if (r != nil) != tt.wantPanic {
t.Errorf("GRPCServerWrapper.Run() recover = %v, wantPanic = %v", r, tt.wantPanic)
}
}()

wrapper := NewGRPCServerWrapper(tt.timeout, tt.grpcServer, tt.listener)
if !tt.wantPanic && wrapper == nil {
t.Error("Expect not nil wrapper")
}
})
}
}

0 comments on commit c717377

Please sign in to comment.