Skip to content

Commit

Permalink
Merge pull request kubernetes-sigs#1634 from MartinForReal/shafan/grpc
Browse files Browse the repository at this point in the history
Refactor: remove grpc wrapper
  • Loading branch information
andyzhangx authored Dec 22, 2023
2 parents 6d52a19 + bd8e20b commit 0406086
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 214 deletions.
36 changes: 29 additions & 7 deletions pkg/azurefile/azurefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"net/url"
"os/exec"
Expand All @@ -33,11 +34,12 @@ import (
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/pborman/uuid"
"github.com/rubiojr/go-vhd/vhd"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/klog/v2"
Expand Down Expand Up @@ -248,6 +250,7 @@ type Driver struct {
printVolumeStatsCallLogs bool
fileClient *azureFileClient
mounter *mount.SafeFormatAndMount
server *grpc.Server
// lock per volume attach (only for vhd disk feature)
volLockMap *lockMap
// only for nfs feature
Expand Down Expand Up @@ -353,7 +356,7 @@ func NewDriver(options *DriverOptions) *Driver {
}

// Run driver initialization
func (d *Driver) Run(endpoint, kubeconfig string, testBool bool) {
func (d *Driver) Run(ctx context.Context, endpoint, kubeconfig string) error {
versionMeta, err := GetVersionYAML(d.Name)
if err != nil {
klog.Fatalf("%v", err)
Expand Down Expand Up @@ -407,10 +410,29 @@ func (d *Driver) Run(endpoint, kubeconfig string, testBool bool) {
}
d.AddNodeServiceCapabilities(nodeCap)

s := csicommon.NewNonBlockingGRPCServer()
// Driver d act as IdentityServer, ControllerServer and NodeServer
s.Start(endpoint, d, d, d, testBool)
s.Wait()
//setup grpc server
opts := []grpc.ServerOption{
grpc.UnaryInterceptor(csicommon.LogGRPC),
}
server := grpc.NewServer(opts...)
csi.RegisterIdentityServer(server, d)
csi.RegisterControllerServer(server, d)
csi.RegisterNodeServer(server, d)
d.server = server

listener, err := csicommon.ListenEndpoint(endpoint)
if err != nil {
klog.Fatalf("failed to listen endpoint: %v", err)
}
go func() {
<-ctx.Done()
d.server.GracefulStop()
}()
if err = d.server.Serve(listener); errors.Is(err, grpc.ErrServerStopped) {
klog.Infof("gRPC server stopped serving")
return nil
}
return err
}

// getFileShareQuota return (-1, nil) means file share does not exist
Expand Down Expand Up @@ -1160,7 +1182,7 @@ func (d *Driver) SetAzureCredentials(ctx context.Context, accountName, accountKe
Type: "Opaque",
}
_, err := d.cloud.KubeClient.CoreV1().Secrets(secretNamespace).Create(ctx, secret, metav1.CreateOptions{})
if errors.IsAlreadyExists(err) {
if apierrors.IsAlreadyExists(err) {
err = nil
}
if err != nil {
Expand Down
20 changes: 18 additions & 2 deletions pkg/azurefile/azurefile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"reflect"
"sort"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2021-09-01/storage"
azure2 "github.com/Azure/go-autorest/autorest/azure"
Expand Down Expand Up @@ -1052,7 +1053,15 @@ func TestRun(t *testing.T) {
os.Setenv(DefaultAzureCredentialFileEnv, fakeCredFile)

d := NewFakeDriver()
d.Run("tcp://127.0.0.1:0", "", true)
ctx, cancelFn := context.WithCancel(context.Background())
go func() {
time.Sleep(1 * time.Second)
cancelFn()
}()
if err := d.Run(ctx, "tcp://127.0.0.1:0", ""); err != nil {
t.Error(err.Error())
}

},
},
{
Expand All @@ -1077,9 +1086,16 @@ func TestRun(t *testing.T) {
os.Setenv(DefaultAzureCredentialFileEnv, fakeCredFile)

d := NewFakeDriver()
ctx, cancelFn := context.WithCancel(context.Background())
go func() {
time.Sleep(1 * time.Second)
cancelFn()
}()
d.cloud = &azure.Cloud{}
d.NodeID = ""
d.Run("tcp://127.0.0.1:0", "", true)
if err := d.Run(ctx, "tcp://127.0.0.1:0", ""); err != nil {
t.Error(err.Error())
}
},
},
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/azurefileplugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ func handle() {
if driver == nil {
klog.Fatalln("Failed to initialize azurefile CSI Driver")
}
driver.Run(*endpoint, *kubeconfig, false)
if err := driver.Run(context.Background(), *endpoint, *kubeconfig); err != nil {
klog.Fatalln(err)
}
}

func exportMetrics() {
Expand Down
122 changes: 0 additions & 122 deletions pkg/csi-common/server.go

This file was deleted.

67 changes: 0 additions & 67 deletions pkg/csi-common/server_test.go

This file was deleted.

28 changes: 26 additions & 2 deletions pkg/csi-common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package csicommon

import (
"fmt"
"net"
"os"
"runtime"
"strings"

"golang.org/x/net/context"
Expand All @@ -28,7 +31,7 @@ import (
"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
)

func ParseEndpoint(ep string) (string, string, error) {
func parseEndpoint(ep string) (string, string, error) {
if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") {
s := strings.SplitN(ep, "://", 2)
if s[1] != "" {
Expand All @@ -37,6 +40,27 @@ func ParseEndpoint(ep string) (string, string, error) {
}
return "", "", fmt.Errorf("Invalid endpoint: %v", ep)
}
func ListenEndpoint(endpoint string) (net.Listener, error) {
proto, addr, err := parseEndpoint(endpoint)
if err != nil {
klog.Fatal(err.Error())
}

if proto == "unix" {
if runtime.GOOS != "windows" {
addr = "/" + addr
}
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
klog.Fatalf("Failed to remove %s, error: %s", addr, err.Error())
}
}

listener, err := net.Listen(proto, addr)
if err != nil {
klog.Fatalf("Failed to listen: %v", err)
}
return listener, err
}

func NewVolumeCapabilityAccessMode(mode csi.VolumeCapability_AccessMode_Mode) *csi.VolumeCapability_AccessMode {
return &csi.VolumeCapability_AccessMode{Mode: mode}
Expand Down Expand Up @@ -71,7 +95,7 @@ func getLogLevel(method string) int32 {
return 2
}

func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func LogGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
level := klog.Level(getLogLevel(info.FullMethod))
klog.V(level).Infof("GRPC call: %s", info.FullMethod)
klog.V(level).Infof("GRPC request: %s", protosanitizer.StripSecrets(req))
Expand Down
Loading

0 comments on commit 0406086

Please sign in to comment.