diff --git a/cmd/crusoe-csi-driver/main.go b/cmd/crusoe-csi-driver/main.go index bcae318..1662c97 100644 --- a/cmd/crusoe-csi-driver/main.go +++ b/cmd/crusoe-csi-driver/main.go @@ -36,10 +36,10 @@ func setFlags() { replacer := strings.NewReplacer("-", "_") viper.SetEnvKeyReplacer(replacer) - rootCmd.Flags().String(internal.CrusoeAPIEndpointFlag, internal.CrusoeAPIEndpointDefault, "help for api endpoint") - rootCmd.Flags().String(internal.CrusoeAccessKeyFlag, "", "help for access key") - rootCmd.Flags().String(internal.CrusoeSecretKeyFlag, "", "help for secret key") - rootCmd.Flags().String(internal.CrusoeProjectIDFlag, "", "help for project id") + rootCmd.Flags().String(internal.CrusoeAPIEndpointFlag, internal.CrusoeAPIEndpointDefault, "Crusoe API endpoint") + rootCmd.Flags().String(internal.CrusoeAccessKeyFlag, "", "Crusoe Access Key") + rootCmd.Flags().String(internal.CrusoeSecretKeyFlag, "", "Crusoe Secret Key") + rootCmd.Flags().String(internal.CrusoeProjectIDFlag, "", "Cluster Project ID") rootCmd.Flags().Var( enumflag.New(&internal.SelectedCSIDriverType, internal.CSIDriverTypeFlag, @@ -54,8 +54,8 @@ func setFlags() { true), internal.ServicesFlag, "help for services") - rootCmd.Flags().String(internal.NodeNameFlag, "", "help for kubernetes node name") - rootCmd.Flags().String(internal.SocketAddressFlag, internal.SocketAddressDefault, "help for socket address") + rootCmd.Flags().String(internal.NodeNameFlag, "", "Kubernetes Node Name") + rootCmd.Flags().String(internal.SocketAddressFlag, internal.SocketAddressDefault, "CSI Socket Address") err = viper.BindPFlags(rootCmd.Flags()) if err != nil { diff --git a/internal/server.go b/internal/server.go index d267fb3..fd7889f 100644 --- a/internal/server.go +++ b/internal/server.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "os/signal" + "strings" "sync" "syscall" @@ -40,6 +41,8 @@ import ( const ( projectIDEnvKey = "CRUSOE_PROJECT_ID" projectIDLabelKey = "crusoe.ai/project.id" + + numExpectedComponents = 2 ) var ( @@ -47,6 +50,7 @@ var ( errMultipleInstances = errors.New("multiple instances found") errProjectIDNotFound = fmt.Errorf("project ID not found in %s env var or %s node label", projectIDEnvKey, projectIDLabelKey) + errInvalidNodeName = errors.New("invalid node name") ) func interruptHandler() (*sync.WaitGroup, context.Context) { @@ -80,7 +84,12 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) { "crusoe-csi-driver/0.0.1", ) - nodeName := viper.GetString(NodeNameFlag) + fullNodeName := viper.GetString(NodeNameFlag) + nodeNameSplit := strings.SplitN(fullNodeName, ".", numExpectedComponents) + if len(nodeNameSplit) < 1 { + return nil, errInvalidNodeName + } + shortNodeName := nodeNameSplit[0] var projectID string @@ -96,7 +105,7 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) { if err != nil { return nil, fmt.Errorf("could not get kube client: %w", err) } - hostNode, nodeFetchErr := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + hostNode, nodeFetchErr := kubeClient.CoreV1().Nodes().Get(ctx, fullNodeName, metav1.GetOptions{}) if nodeFetchErr != nil { return nil, fmt.Errorf("could not fetch current node with kube client: %w", err) } @@ -109,16 +118,16 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) { instances, _, err := crusoeClient.VMsApi.ListInstances(ctx, projectID, &crusoeapi.VMsApiListInstancesOpts{ - Names: optional.NewString(nodeName), + Names: optional.NewString(shortNodeName), }) if err != nil { return nil, fmt.Errorf("failed to list instances: %w", err) } if len(instances.Items) == 0 { - return nil, fmt.Errorf("%w: %s", errInstanceNotFound, nodeName) + return nil, fmt.Errorf("%w: %s", errInstanceNotFound, shortNodeName) } else if len(instances.Items) > 1 { - return nil, fmt.Errorf("%w: %s", errMultipleInstances, nodeName) + return nil, fmt.Errorf("%w: %s", errMultipleInstances, shortNodeName) } return &instances.Items[0], nil