Skip to content

Commit

Permalink
separate controller and node structs from the driver
Browse files Browse the repository at this point in the history
  • Loading branch information
dhij committed Aug 19, 2022
1 parent 8b884e9 commit b7f9fe8
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 114 deletions.
21 changes: 19 additions & 2 deletions cmd/do-csi-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ func main() {

drv, err := driver.NewDriver(driver.NewDriverParams{
Endpoint: *endpoint,
Token: *token,
URL: *url,
Region: *region,
DOTag: *doTag,
DriverName: *driverName,
Expand All @@ -65,6 +63,19 @@ func main() {
log.Fatalln(err)
}

ctrl, err := driver.NewController(drv, driver.NewControllerParams{
Token: *token,
URL: *url,
})
if err != nil {
log.Fatalln(err)
}

node, err := driver.NewNode(drv)
if err != nil {
log.Fatalln(err)
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -78,4 +89,10 @@ func main() {
if err := drv.Run(ctx); err != nil {
log.Fatalln(err)
}
if err := ctrl.Run(ctx); err != nil {
log.Fatalln(err)
}
if err := node.Run(ctx); err != nil {
log.Fatalln(err)
}
}
122 changes: 105 additions & 17 deletions driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/digitalocean/godo"
"github.com/golang/protobuf/ptypes"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -78,9 +79,96 @@ var (
}
)

type Controller struct {
*Driver

storage godo.StorageService
storageActions godo.StorageActionsService
droplets godo.DropletsService
snapshots godo.SnapshotsService
account godo.AccountService
tags godo.TagsService

healthChecker *HealthChecker
}

type NewControllerParams struct {
Token string
URL string
}

func NewController(driver *Driver, p NewControllerParams) (*Controller, error) {
var opts []godo.ClientOpt
opts = append(opts, godo.SetBaseURL(p.URL))

if version == "" {
version = "dev"
}
opts = append(opts, godo.SetUserAgent("csi-digitalocean/"+version))

tokenSource := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: p.Token,
})
oauthClient := oauth2.NewClient(context.Background(), tokenSource)

doClient, err := godo.New(oauthClient, opts...)
if err != nil {
return nil, fmt.Errorf("couldn't initialize DigitalOcean client: %s", err)
}

healthChecker := NewHealthChecker(&doHealthChecker{account: doClient.Account})

return &Controller{
Driver: driver,
storage: doClient.Storage,
storageActions: doClient.StorageActions,
droplets: doClient.Droplets,
snapshots: doClient.Snapshots,
account: doClient.Account,
tags: doClient.Tags,

healthChecker: healthChecker,
}, nil
}

// Run starts the CSI plugin by communication over the given endpoint
func (d *Controller) Run(ctx context.Context) error {
details, err := d.checkLimit(context.Background())
if err != nil {
return fmt.Errorf("failed to check volumes limits on startup: %s", err)
}
if details != nil {
d.log.WithFields(logrus.Fields{
"limit": details.limit,
"num_volumes": details.numVolumes,
}).Warn("CSI plugin will not function correctly, please resolve volume limit")
}

if d.debugAddr != "" {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
err := d.healthChecker.Check(r.Context())
if err != nil {
d.log.WithError(err).Error("executing health check")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
})
d.httpSrv = &http.Server{
Addr: d.debugAddr,
Handler: mux,
}
}

csi.RegisterControllerServer(d.srv, d)

return nil
}

// CreateVolume creates a new volume from the given request. The function is
// idempotent.
func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
func (d *Controller) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "CreateVolume Name must be provided")
}
Expand Down Expand Up @@ -230,7 +318,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
}

// DeleteVolume deletes the given volume. The function is idempotent.
func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
func (d *Controller) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "DeleteVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -259,7 +347,7 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest)
}

// ControllerPublishVolume attaches the given volume to the node
func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
func (d *Controller) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ControllerPublishVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -389,7 +477,7 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle
}

// ControllerUnpublishVolume deattaches the given volume from the node
func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
func (d *Controller) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ControllerUnpublishVolume Volume ID must be provided")
}
Expand Down Expand Up @@ -475,7 +563,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control

// ValidateVolumeCapabilities checks whether the volume capabilities requested
// are supported.
func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
func (d *Controller) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
if req.VolumeId == "" {
return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities Volume ID must be provided")
}
Expand Down Expand Up @@ -517,7 +605,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida
}

// ListVolumes returns a list of all requested volumes
func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
func (d *Controller) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
maxEntries := req.MaxEntries
if maxEntries == 0 && d.defaultVolumesPageSize > 0 {
maxEntries = int32(d.defaultVolumesPageSize)
Expand Down Expand Up @@ -596,7 +684,7 @@ func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (
}

// GetCapacity returns the capacity of the storage pool
func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
func (d *Controller) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
// TODO(arslan): check if we can provide this information somehow
d.log.WithFields(logrus.Fields{
"params": req.Parameters,
Expand All @@ -606,7 +694,7 @@ func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (
}

// ControllerGetCapabilities returns the capabilities of the controller service.
func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
func (d *Controller) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
newCap := func(cap csi.ControllerServiceCapability_RPC_Type) *csi.ControllerServiceCapability {
return &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Expand Down Expand Up @@ -643,7 +731,7 @@ func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.Control

// CreateSnapshot will be called by the CO to create a new snapshot from a
// source volume on behalf of a user.
func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
func (d *Controller) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
if req.GetName() == "" {
return nil, status.Error(codes.InvalidArgument, "CreateSnapshot Name must be provided")
}
Expand Down Expand Up @@ -739,7 +827,7 @@ func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequ
}

// DeleteSnapshot will be called by the CO to delete a snapshot.
func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
func (d *Controller) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
log := d.log.WithFields(logrus.Fields{
"req_snapshot_id": req.GetSnapshotId(),
"method": "delete_snapshot",
Expand Down Expand Up @@ -772,7 +860,7 @@ func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequ
// system within the given parameters regardless of how they were created.
// ListSnapshots shold not list a snapshot that is being created but has not
// been cut successfully yet.
func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
func (d *Controller) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
listResp := &csi.ListSnapshotsResponse{}
log := d.log.WithFields(logrus.Fields{
"snapshot_id": req.SnapshotId,
Expand Down Expand Up @@ -862,7 +950,7 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
}

// ControllerExpandVolume is called from the resizer to increase the volume size.
func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
func (d *Controller) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
volID := req.GetVolumeId()

if len(volID) == 0 {
Expand Down Expand Up @@ -928,15 +1016,15 @@ func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.Controller
// The call is used for the CSI health check feature
// (https://github.com/kubernetes/enhancements/pull/1077) which we do not
// support yet.
func (d *Driver) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
func (d *Controller) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
return nil, status.Error(codes.Unimplemented, "")
}

// extractStorage extracts the storage size in bytes from the given capacity
// range. If the capacity range is not satisfied it returns the default volume
// size. If the capacity range is above supported sizes, it returns an
// error. If the capacity range is below supported size, it returns the minimum supported size
func (d *Driver) extractStorage(capRange *csi.CapacityRange) (int64, error) {
func (d *Controller) extractStorage(capRange *csi.CapacityRange) (int64, error) {
if capRange == nil {
return defaultVolumeSizeInBytes, nil
}
Expand Down Expand Up @@ -1016,7 +1104,7 @@ func formatBytes(inputBytes int64) string {
}

// waitAction waits until the given action for the volume has completed.
func (d *Driver) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error {
func (d *Controller) waitAction(ctx context.Context, log *logrus.Entry, volumeID string, actionID int) error {
err := wait.PollUntil(1*time.Second, func() (done bool, err error) {
action, _, err := d.storageActions.Get(ctx, volumeID, actionID)
if err != nil {
Expand Down Expand Up @@ -1057,7 +1145,7 @@ type limitDetails struct {
}

// checkLimit checks whether the user hit their account volume limit.
func (d *Driver) checkLimit(ctx context.Context) (*limitDetails, error) {
func (d *Controller) checkLimit(ctx context.Context) (*limitDetails, error) {
// only one provisioner runs, we can make sure to prevent burst creation
d.readyMu.Lock()
defer d.readyMu.Unlock()
Expand Down Expand Up @@ -1144,7 +1232,7 @@ func validateCapabilities(caps []*csi.VolumeCapability) []string {
return violations.List()
}

func (d *Driver) tagVolume(parentCtx context.Context, vol *godo.Volume) error {
func (d *Controller) tagVolume(parentCtx context.Context, vol *godo.Volume) error {
for _, tag := range vol.Tags {
if tag == d.doTag {
return nil
Expand Down
Loading

0 comments on commit b7f9fe8

Please sign in to comment.