From bd8e20bf5a8721088ed8fce0fceb17ab5d243617 Mon Sep 17 00:00:00 2001 From: Fan Shang Xiang Date: Thu, 21 Dec 2023 15:52:58 +0800 Subject: [PATCH] remove grpc wrapper --- pkg/azurefile/azurefile.go | 36 ++++++++-- pkg/azurefile/azurefile_test.go | 20 +++++- pkg/azurefileplugin/main.go | 4 +- pkg/csi-common/server.go | 122 -------------------------------- pkg/csi-common/server_test.go | 67 ------------------ pkg/csi-common/utils.go | 28 +++++++- pkg/csi-common/utils_test.go | 24 +++---- test/e2e/suite_test.go | 3 +- 8 files changed, 90 insertions(+), 214 deletions(-) delete mode 100644 pkg/csi-common/server.go delete mode 100644 pkg/csi-common/server_test.go diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index ea98dba181..c03a8bf1d7 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "fmt" "net/url" "os/exec" @@ -33,11 +34,12 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "github.com/pborman/uuid" "github.com/rubiojr/go-vhd/vhd" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" @@ -248,6 +250,7 @@ type Driver struct { printVolumeStatsCallLogs bool fileClient *azureFileClient mounter *mount.SafeFormatAndMount + server *grpc.Server // lock per volume attach (only for vhd disk feature) volLockMap *lockMap // only for nfs feature @@ -353,7 +356,7 @@ func NewDriver(options *DriverOptions) *Driver { } // Run driver initialization -func (d *Driver) Run(endpoint, kubeconfig string, testBool bool) { +func (d *Driver) Run(ctx context.Context, endpoint, kubeconfig string) error { versionMeta, err := GetVersionYAML(d.Name) if err != nil { klog.Fatalf("%v", err) @@ -407,10 +410,29 @@ func (d *Driver) Run(endpoint, kubeconfig string, testBool bool) { } d.AddNodeServiceCapabilities(nodeCap) - s := csicommon.NewNonBlockingGRPCServer() - // Driver d act as IdentityServer, ControllerServer and NodeServer - s.Start(endpoint, d, d, d, testBool) - s.Wait() + //setup grpc server + opts := []grpc.ServerOption{ + grpc.UnaryInterceptor(csicommon.LogGRPC), + } + server := grpc.NewServer(opts...) + csi.RegisterIdentityServer(server, d) + csi.RegisterControllerServer(server, d) + csi.RegisterNodeServer(server, d) + d.server = server + + listener, err := csicommon.ListenEndpoint(endpoint) + if err != nil { + klog.Fatalf("failed to listen endpoint: %v", err) + } + go func() { + <-ctx.Done() + d.server.GracefulStop() + }() + if err = d.server.Serve(listener); errors.Is(err, grpc.ErrServerStopped) { + klog.Infof("gRPC server stopped serving") + return nil + } + return err } // getFileShareQuota return (-1, nil) means file share does not exist @@ -1160,7 +1182,7 @@ func (d *Driver) SetAzureCredentials(ctx context.Context, accountName, accountKe Type: "Opaque", } _, err := d.cloud.KubeClient.CoreV1().Secrets(secretNamespace).Create(ctx, secret, metav1.CreateOptions{}) - if errors.IsAlreadyExists(err) { + if apierrors.IsAlreadyExists(err) { err = nil } if err != nil { diff --git a/pkg/azurefile/azurefile_test.go b/pkg/azurefile/azurefile_test.go index a81f945787..04dff01f65 100644 --- a/pkg/azurefile/azurefile_test.go +++ b/pkg/azurefile/azurefile_test.go @@ -26,6 +26,7 @@ import ( "reflect" "sort" "testing" + "time" "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2021-09-01/storage" azure2 "github.com/Azure/go-autorest/autorest/azure" @@ -1052,7 +1053,15 @@ func TestRun(t *testing.T) { os.Setenv(DefaultAzureCredentialFileEnv, fakeCredFile) d := NewFakeDriver() - d.Run("tcp://127.0.0.1:0", "", true) + ctx, cancelFn := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancelFn() + }() + if err := d.Run(ctx, "tcp://127.0.0.1:0", ""); err != nil { + t.Error(err.Error()) + } + }, }, { @@ -1077,9 +1086,16 @@ func TestRun(t *testing.T) { os.Setenv(DefaultAzureCredentialFileEnv, fakeCredFile) d := NewFakeDriver() + ctx, cancelFn := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancelFn() + }() d.cloud = &azure.Cloud{} d.NodeID = "" - d.Run("tcp://127.0.0.1:0", "", true) + if err := d.Run(ctx, "tcp://127.0.0.1:0", ""); err != nil { + t.Error(err.Error()) + } }, }, } diff --git a/pkg/azurefileplugin/main.go b/pkg/azurefileplugin/main.go index 94d59e4a78..729cfc7234 100644 --- a/pkg/azurefileplugin/main.go +++ b/pkg/azurefileplugin/main.go @@ -120,7 +120,9 @@ func handle() { if driver == nil { klog.Fatalln("Failed to initialize azurefile CSI Driver") } - driver.Run(*endpoint, *kubeconfig, false) + if err := driver.Run(context.Background(), *endpoint, *kubeconfig); err != nil { + klog.Fatalln(err) + } } func exportMetrics() { diff --git a/pkg/csi-common/server.go b/pkg/csi-common/server.go deleted file mode 100644 index 5d3209a15f..0000000000 --- a/pkg/csi-common/server.go +++ /dev/null @@ -1,122 +0,0 @@ -/* -Copyright 2017 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package csicommon - -import ( - "net" - "os" - "runtime" - "sync" - "time" - - "google.golang.org/grpc" - "k8s.io/klog/v2" - - "github.com/container-storage-interface/spec/lib/go/csi" -) - -// Defines Non blocking GRPC server interfaces -type NonBlockingGRPCServer interface { - // Start services at the endpoint - Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) - // Waits for the service to stop - Wait() - // Stops the service gracefully - Stop() - // Stops the service forcefully - ForceStop() -} - -func NewNonBlockingGRPCServer() NonBlockingGRPCServer { - return &nonBlockingGRPCServer{} -} - -// NonBlocking server -type nonBlockingGRPCServer struct { - wg sync.WaitGroup - server *grpc.Server -} - -func (s *nonBlockingGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) { - s.wg.Add(1) - go s.serve(endpoint, ids, cs, ns, testMode) -} - -func (s *nonBlockingGRPCServer) Wait() { - s.wg.Wait() -} - -func (s *nonBlockingGRPCServer) Stop() { - s.server.GracefulStop() -} - -func (s *nonBlockingGRPCServer) ForceStop() { - s.server.Stop() -} - -func (s *nonBlockingGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode bool) { - - proto, addr, err := ParseEndpoint(endpoint) - if err != nil { - klog.Fatal(err.Error()) - } - - if proto == "unix" { - if runtime.GOOS != "windows" { - addr = "/" + addr - } - if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { - klog.Fatalf("Failed to remove %s, error: %s", addr, err.Error()) - } - } - - listener, err := net.Listen(proto, addr) - if err != nil { - klog.Fatalf("Failed to listen: %v", err) - } - - opts := []grpc.ServerOption{ - grpc.UnaryInterceptor(logGRPC), - } - server := grpc.NewServer(opts...) - s.server = server - - if ids != nil { - csi.RegisterIdentityServer(server, ids) - } - if cs != nil { - csi.RegisterControllerServer(server, cs) - } - if ns != nil { - csi.RegisterNodeServer(server, ns) - } - // Used to stop the server while running tests - if testMode { - s.wg.Done() - go func() { - // make sure Serve() is called - s.wg.Wait() - time.Sleep(time.Millisecond * 1000) - s.server.GracefulStop() - }() - } - - klog.Infof("Listening for connections on address: %#v", listener.Addr()) - if err := server.Serve(listener); err != nil { - klog.Errorf("Listening for connections on address: %#v, error: %v", listener.Addr(), err) - } -} diff --git a/pkg/csi-common/server_test.go b/pkg/csi-common/server_test.go deleted file mode 100644 index 18e07712c4..0000000000 --- a/pkg/csi-common/server_test.go +++ /dev/null @@ -1,67 +0,0 @@ -/* -Copyright 2017 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package csicommon - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" -) - -func TestNewNonBlockingGRPCServer(t *testing.T) { - s := NewNonBlockingGRPCServer() - assert.NotNil(t, s) -} - -func TestStart(_ *testing.T) { - s := NewNonBlockingGRPCServer() - // sleep a while to avoid race condition in unit test - time.Sleep(time.Millisecond * 500) - s.Start("tcp://127.0.0.1:0", nil, nil, nil, true) - time.Sleep(time.Millisecond * 500) -} - -func TestServe(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.wg = sync.WaitGroup{} - //need to add one here as the actual also requires one. - s.wg.Add(1) - s.serve("tcp://127.0.0.1:0", nil, nil, nil, true) -} - -func TestWait(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.wg = sync.WaitGroup{} - s.Wait() -} - -func TestStop(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.Stop() -} - -func TestForceStop(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.ForceStop() -} diff --git a/pkg/csi-common/utils.go b/pkg/csi-common/utils.go index dde6a64adc..b46146dabd 100644 --- a/pkg/csi-common/utils.go +++ b/pkg/csi-common/utils.go @@ -18,6 +18,9 @@ package csicommon import ( "fmt" + "net" + "os" + "runtime" "strings" "golang.org/x/net/context" @@ -28,7 +31,7 @@ import ( "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" ) -func ParseEndpoint(ep string) (string, string, error) { +func parseEndpoint(ep string) (string, string, error) { if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") { s := strings.SplitN(ep, "://", 2) if s[1] != "" { @@ -37,6 +40,27 @@ func ParseEndpoint(ep string) (string, string, error) { } return "", "", fmt.Errorf("Invalid endpoint: %v", ep) } +func ListenEndpoint(endpoint string) (net.Listener, error) { + proto, addr, err := parseEndpoint(endpoint) + if err != nil { + klog.Fatal(err.Error()) + } + + if proto == "unix" { + if runtime.GOOS != "windows" { + addr = "/" + addr + } + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { + klog.Fatalf("Failed to remove %s, error: %s", addr, err.Error()) + } + } + + listener, err := net.Listen(proto, addr) + if err != nil { + klog.Fatalf("Failed to listen: %v", err) + } + return listener, err +} func NewVolumeCapabilityAccessMode(mode csi.VolumeCapability_AccessMode_Mode) *csi.VolumeCapability_AccessMode { return &csi.VolumeCapability_AccessMode{Mode: mode} @@ -71,7 +95,7 @@ func getLogLevel(method string) int32 { return 2 } -func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +func LogGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { level := klog.Level(getLogLevel(info.FullMethod)) klog.V(level).Infof("GRPC call: %s", info.FullMethod) klog.V(level).Infof("GRPC request: %s", protosanitizer.StripSecrets(req)) diff --git a/pkg/csi-common/utils_test.go b/pkg/csi-common/utils_test.go index 811b05f272..0bbde2de8e 100644 --- a/pkg/csi-common/utils_test.go +++ b/pkg/csi-common/utils_test.go @@ -31,53 +31,53 @@ import ( func TestParseEndpoint(t *testing.T) { //Valid unix domain socket endpoint - sockType, addr, err := ParseEndpoint("unix://fake.sock") + sockType, addr, err := parseEndpoint("unix://fake.sock") assert.NoError(t, err) assert.Equal(t, sockType, "unix") assert.Equal(t, addr, "fake.sock") - sockType, addr, err = ParseEndpoint("unix:///fakedir/fakedir/fake.sock") + sockType, addr, err = parseEndpoint("unix:///fakedir/fakedir/fake.sock") assert.NoError(t, err) assert.Equal(t, sockType, "unix") assert.Equal(t, addr, "/fakedir/fakedir/fake.sock") //Valid unix domain socket with uppercase - sockType, addr, err = ParseEndpoint("UNIX://fake.sock") + sockType, addr, err = parseEndpoint("UNIX://fake.sock") assert.NoError(t, err) assert.Equal(t, sockType, "UNIX") assert.Equal(t, addr, "fake.sock") //Valid TCP endpoint with ip - sockType, addr, err = ParseEndpoint("tcp://127.0.0.1:80") + sockType, addr, err = parseEndpoint("tcp://127.0.0.1:80") assert.NoError(t, err) assert.Equal(t, sockType, "tcp") assert.Equal(t, addr, "127.0.0.1:80") //Valid TCP endpoint with uppercase - sockType, addr, err = ParseEndpoint("TCP://127.0.0.1:80") + sockType, addr, err = parseEndpoint("TCP://127.0.0.1:80") assert.NoError(t, err) assert.Equal(t, sockType, "TCP") assert.Equal(t, addr, "127.0.0.1:80") //Valid TCP endpoint with hostname - sockType, addr, err = ParseEndpoint("tcp://fakehost:80") + sockType, addr, err = parseEndpoint("tcp://fakehost:80") assert.NoError(t, err) assert.Equal(t, sockType, "tcp") assert.Equal(t, addr, "fakehost:80") - _, _, err = ParseEndpoint("unix:/fake.sock/") + _, _, err = parseEndpoint("unix:/fake.sock/") assert.NotNil(t, err) - _, _, err = ParseEndpoint("fake.sock") + _, _, err = parseEndpoint("fake.sock") assert.NotNil(t, err) - _, _, err = ParseEndpoint("unix://") + _, _, err = parseEndpoint("unix://") assert.NotNil(t, err) - _, _, err = ParseEndpoint("://") + _, _, err = parseEndpoint("://") assert.NotNil(t, err) - _, _, err = ParseEndpoint("") + _, _, err = parseEndpoint("") assert.NotNil(t, err) } @@ -132,7 +132,7 @@ func TestLogGRPC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // EXECUTE - _, _ = logGRPC(context.Background(), test.req, &info, handler) + _, _ = LogGRPC(context.Background(), test.req, &info, handler) klog.Flush() // ASSERT diff --git a/test/e2e/suite_test.go b/test/e2e/suite_test.go index f7ccf7bf1e..5eec4d781b 100644 --- a/test/e2e/suite_test.go +++ b/test/e2e/suite_test.go @@ -150,7 +150,8 @@ var _ = ginkgo.BeforeSuite(func(ctx ginkgo.SpecContext) { azurefileDriver = azurefile.NewDriver(&driverOptions) go func() { os.Setenv("AZURE_CREDENTIAL_FILE", credentials.TempAzureCredentialFilePath) - azurefileDriver.Run(fmt.Sprintf("unix:///tmp/csi-%s.sock", uuid.NewUUID().String()), kubeconfig, false) + err := azurefileDriver.Run(context.Background(), fmt.Sprintf("unix:///tmp/csi-%s.sock", uuid.NewUUID().String()), kubeconfig) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) }() } })