diff --git a/cmd/main.go b/cmd/main.go index 8b9603df..c2d8e446 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,7 +8,9 @@ import ( "fmt" "log" "net" + "os" "strings" + "time" "github.com/opiproject/gospdk/spdk" @@ -16,6 +18,7 @@ import ( "github.com/opiproject/opi-spdk-bridge/pkg/frontend" "github.com/opiproject/opi-spdk-bridge/pkg/kvm" "github.com/opiproject/opi-spdk-bridge/pkg/middleend" + "github.com/opiproject/opi-spdk-bridge/pkg/server" pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go" "google.golang.org/grpc" @@ -89,8 +92,12 @@ func main() { reflection.Register(s) - log.Printf("Server listening at %v", lis.Addr()) - if err := s.Serve(lis); err != nil { - log.Fatalf("failed to serve: %v", err) + wrapper := server.NewGRPCServerWrapper(2*time.Second, s, lis) + + wrapper.RunAsync() + if err := wrapper.Wait(); err != nil { + log.Printf("Server error: %v", err) + os.Exit(-1) } + log.Print("Server successfully stopped") } diff --git a/pkg/server/grpcserver.go b/pkg/server/grpcserver.go new file mode 100644 index 00000000..a4920916 --- /dev/null +++ b/pkg/server/grpcserver.go @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2023 Intel Corporation + +// Package server implements the server +package server + +import ( + "context" + "errors" + "log" + "net" + "os" + "os/signal" + "syscall" + "time" + + "google.golang.org/grpc" +) + +// GRPCServerWrapper wraps gRPC server to provide graceful shutdown capabilities +type GRPCServerWrapper struct { + waitSignal chan os.Signal + signalsToWait []os.Signal + + timeout time.Duration + + server *grpc.Server + listener net.Listener + waitServeComplete chan error + serve func(*grpc.Server, net.Listener) error +} + +func defaultServe(s *grpc.Server, l net.Listener) error { return s.Serve(l) } + +// NewGRPCServerWrapper creates a new instance of GRPCServerWrapper +func NewGRPCServerWrapper( + timeout time.Duration, server *grpc.Server, listener net.Listener, +) *GRPCServerWrapper { + if timeout == 0 { + log.Panicf("timeout cannot be zero") + } + + if server == nil { + log.Panicf("grpc server cannot be nil") + } + + if listener == nil { + log.Panic("listener cannot be nil") + } + + return &GRPCServerWrapper{ + waitSignal: make(chan os.Signal, 1), + signalsToWait: []os.Signal{syscall.SIGINT, syscall.SIGTERM}, + timeout: timeout, + server: server, + listener: listener, + waitServeComplete: make(chan error, 1), + serve: defaultServe, + } +} + +// RunAsync runs gRPC server +func (s *GRPCServerWrapper) RunAsync() { + go func() { + log.Printf("Server listening at %v", s.listener.Addr()) + s.waitServeComplete <- s.serve(s.server, s.listener) + }() +} + +// Wait waits for a signal and handles graceful completion +func (s *GRPCServerWrapper) Wait() error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + signal.Notify(s.waitSignal, s.signalsToWait...) + select { + case sig := <-s.waitSignal: + log.Printf("Got signal: %v", sig) + log.Printf("Start graceful shutdown with timeout: %v", s.timeout) + time.AfterFunc(s.timeout, func() { cancel() }) + s.stopServer(ctx) + case <-ctx.Done(): + log.Println("Stop listening for a signal") + } + }() + + select { + case err := <-s.waitServeComplete: + return err + case <-ctx.Done(): + return errors.New("server stop timeout elapsed") + } +} + +func (s *GRPCServerWrapper) stopServer(ctx context.Context) { + log.Println("Stop server") + + stopped := make(chan struct{}, 1) + go func() { + s.server.GracefulStop() + close(stopped) + }() + + select { + case <-ctx.Done(): + log.Println("Server stop context done") + s.server.Stop() + case <-stopped: + log.Println("GracefulStop completed") + } +} diff --git a/pkg/server/grpcserver_test.go b/pkg/server/grpcserver_test.go new file mode 100644 index 00000000..0a229b18 --- /dev/null +++ b/pkg/server/grpcserver_test.go @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (C) 2023 Intel Corporation + +// Package server implements the server +package server + +import ( + "context" + "errors" + "log" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go" +) + +const timeout = 10 * time.Millisecond + +type TestServer struct { + pb.MiddleendEncryptionServiceServer + wait time.Duration + startedHandlingCall sync.WaitGroup +} + +func (b *TestServer) CreateEncryptedVolume(_ context.Context, _ *pb.CreateEncryptedVolumeRequest) (*pb.EncryptedVolume, error) { + b.startedHandlingCall.Done() + time.Sleep(b.wait) + return &pb.EncryptedVolume{}, nil +} + +type testEnv struct { + testServer *TestServer + client pb.MiddleendEncryptionServiceClient + conn *grpc.ClientConn + ln net.Listener + grpcServer *grpc.Server +} + +func (e *testEnv) Close() { + CloseGrpcConnection(e.conn) + CloseListener(e.ln) +} + +func createTestEnvironment(callTime time.Duration) *testEnv { + env := &testEnv{} + env.testServer = &TestServer{ + pb.UnimplementedMiddleendEncryptionServiceServer{}, + callTime, + sync.WaitGroup{}, + } + env.grpcServer = grpc.NewServer() + listener := bufconn.Listen(1024 * 1024) + env.ln = listener + pb.RegisterMiddleendEncryptionServiceServer(env.grpcServer, env.testServer) + + ctx := context.Background() + conn, err := grpc.DialContext(ctx, + "", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + })) + if err != nil { + log.Fatal(err) + } + env.conn = conn + env.client = pb.NewMiddleendEncryptionServiceClient(env.conn) + + return env +} + +func TestGRPCWrapperWait(t *testing.T) { + tests := map[string]struct { + callTime time.Duration + wantErr bool + serve func(*grpc.Server, net.Listener) error + }{ + "server stop timeout": { + callTime: timeout * 2, + wantErr: true, + }, + "successful server stop": { + callTime: timeout / 2, + wantErr: false, + }, + } + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + testEnv := createTestEnvironment(tt.callTime) + defer testEnv.Close() + testEnv.testServer.startedHandlingCall.Add(1) + + serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) + // use rare signal in order not to catch a real interrupt + serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} + serverWrapper.RunAsync() + + var ( + clientResponse *pb.EncryptedVolume + clientErr error + ) + clientDone := sync.WaitGroup{} + clientDone.Add(1) + go func() { + clientResponse, clientErr = testEnv.client.CreateEncryptedVolume( + context.Background(), &pb.CreateEncryptedVolumeRequest{}) + clientDone.Done() + }() + testEnv.testServer.startedHandlingCall.Wait() + + serverWrapper.waitSignal <- os.Interrupt + waitErr := serverWrapper.Wait() + + if (waitErr != nil) != tt.wantErr { + t.Errorf("Expected elapsed: %v. received: %v", tt.wantErr, waitErr) + } + clientDone.Wait() + if (clientErr != nil) != tt.wantErr { + t.Errorf("Expected error %v, received: %v", tt.wantErr, clientErr) + } + if (clientResponse == nil) != tt.wantErr { + t.Errorf("Expected not nil response %v, received: %v", tt.wantErr, clientResponse) + } + }) + } + + t.Run("failed serve", func(t *testing.T) { + testEnv := createTestEnvironment(timeout) + defer testEnv.Close() + serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) + // use rare signal in order not to catch a real interrupt + serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} + stubErr := errors.New("some serve error") + serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { return stubErr } + serverWrapper.RunAsync() + + waitErr := serverWrapper.Wait() + + if waitErr != stubErr { + t.Errorf("Expected error: %v, received: %v", stubErr, waitErr) + } + }) + + t.Run("failed serve after signal received", func(t *testing.T) { + testEnv := createTestEnvironment(timeout) + defer testEnv.Close() + serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) + // use rare signal in order not to catch a real interrupt + serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} + stubErr := errors.New("some serve error") + wg := sync.WaitGroup{} + wg.Add(1) + serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { + wg.Wait() + return stubErr + } + serverWrapper.RunAsync() + go func() { + serverWrapper.waitSignal <- os.Interrupt + time.Sleep(timeout / 5) + wg.Done() + }() + + waitErr := serverWrapper.Wait() + + if waitErr != stubErr { + t.Errorf("Expected error: %v, received: %v", stubErr, waitErr) + } + }) +} + +func TestNewGRPCWrapper(t *testing.T) { + tests := map[string]struct { + timeout time.Duration + grpcServer *grpc.Server + listener net.Listener + wantPanic bool + }{ + "zero timeout": { + timeout: 0, + grpcServer: grpc.NewServer(), + listener: bufconn.Listen(32), + wantPanic: true, + }, + "nil grpc server": { + timeout: timeout, + grpcServer: nil, + listener: bufconn.Listen(32), + wantPanic: true, + }, + "nil listener": { + timeout: timeout, + grpcServer: grpc.NewServer(), + listener: nil, + wantPanic: true, + }, + "successful wrapper creation": { + timeout: timeout, + grpcServer: grpc.NewServer(), + listener: bufconn.Listen(32), + wantPanic: false, + }, + } + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + defer func() { + r := recover() + if (r != nil) != tt.wantPanic { + t.Errorf("GRPCServerWrapper.Run() recover = %v, wantPanic = %v", r, tt.wantPanic) + } + }() + + wrapper := NewGRPCServerWrapper(tt.timeout, tt.grpcServer, tt.listener) + if !tt.wantPanic && wrapper == nil { + t.Error("Expect not nil wrapper") + } + }) + } +}