Skip to content

Commit

Permalink
fix: truncate long hostnames to short hostnames when handling node names
Browse files Browse the repository at this point in the history
  • Loading branch information
Clement Liaw committed Nov 29, 2024
1 parent 0a1dd1c commit ef9f7c7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
12 changes: 6 additions & 6 deletions cmd/crusoe-csi-driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func setFlags() {
replacer := strings.NewReplacer("-", "_")
viper.SetEnvKeyReplacer(replacer)

rootCmd.Flags().String(internal.CrusoeAPIEndpointFlag, internal.CrusoeAPIEndpointDefault, "help for api endpoint")
rootCmd.Flags().String(internal.CrusoeAccessKeyFlag, "", "help for access key")
rootCmd.Flags().String(internal.CrusoeSecretKeyFlag, "", "help for secret key")
rootCmd.Flags().String(internal.CrusoeProjectIDFlag, "", "help for project id")
rootCmd.Flags().String(internal.CrusoeAPIEndpointFlag, internal.CrusoeAPIEndpointDefault, "Crusoe API endpoint")
rootCmd.Flags().String(internal.CrusoeAccessKeyFlag, "", "Crusoe Access Key")
rootCmd.Flags().String(internal.CrusoeSecretKeyFlag, "", "Crusoe Secret Key")
rootCmd.Flags().String(internal.CrusoeProjectIDFlag, "", "Cluster Project ID")
rootCmd.Flags().Var(
enumflag.New(&internal.SelectedCSIDriverType,
internal.CSIDriverTypeFlag,
Expand All @@ -54,8 +54,8 @@ func setFlags() {
true),
internal.ServicesFlag,
"help for services")
rootCmd.Flags().String(internal.NodeNameFlag, "", "help for kubernetes node name")
rootCmd.Flags().String(internal.SocketAddressFlag, internal.SocketAddressDefault, "help for socket address")
rootCmd.Flags().String(internal.NodeNameFlag, "", "Kubernetes Node Name")
rootCmd.Flags().String(internal.SocketAddressFlag, internal.SocketAddressDefault, "CSI Socket Address")

err = viper.BindPFlags(rootCmd.Flags())
if err != nil {
Expand Down
19 changes: 14 additions & 5 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"strings"
"sync"
"syscall"

Expand Down Expand Up @@ -40,13 +41,16 @@ import (
const (
projectIDEnvKey = "CRUSOE_PROJECT_ID"
projectIDLabelKey = "crusoe.ai/project.id"

numExpectedComponents = 2
)

var (
errInstanceNotFound = errors.New("instance not found")
errMultipleInstances = errors.New("multiple instances found")
errProjectIDNotFound = fmt.Errorf("project ID not found in %s env var or %s node label",
projectIDEnvKey, projectIDLabelKey)
errInvalidNodeName = errors.New("invalid node name")
)

func interruptHandler() (*sync.WaitGroup, context.Context) {
Expand Down Expand Up @@ -80,7 +84,12 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) {
"crusoe-csi-driver/0.0.1",
)

nodeName := viper.GetString(NodeNameFlag)
fullNodeName := viper.GetString(NodeNameFlag)
nodeNameSplit := strings.SplitN(fullNodeName, ".", numExpectedComponents)
if len(nodeNameSplit) < 1 {
return nil, errInvalidNodeName
}
shortNodeName := nodeNameSplit[0]

var projectID string

Expand All @@ -96,7 +105,7 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) {
if err != nil {
return nil, fmt.Errorf("could not get kube client: %w", err)
}
hostNode, nodeFetchErr := kubeClient.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
hostNode, nodeFetchErr := kubeClient.CoreV1().Nodes().Get(ctx, fullNodeName, metav1.GetOptions{})
if nodeFetchErr != nil {
return nil, fmt.Errorf("could not fetch current node with kube client: %w", err)
}
Expand All @@ -109,16 +118,16 @@ func getHostInstance(ctx context.Context) (*crusoeapi.InstanceV1Alpha5, error) {

instances, _, err := crusoeClient.VMsApi.ListInstances(ctx, projectID,
&crusoeapi.VMsApiListInstancesOpts{
Names: optional.NewString(nodeName),
Names: optional.NewString(shortNodeName),
})
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}

if len(instances.Items) == 0 {
return nil, fmt.Errorf("%w: %s", errInstanceNotFound, nodeName)
return nil, fmt.Errorf("%w: %s", errInstanceNotFound, shortNodeName)
} else if len(instances.Items) > 1 {
return nil, fmt.Errorf("%w: %s", errMultipleInstances, nodeName)
return nil, fmt.Errorf("%w: %s", errMultipleInstances, shortNodeName)
}

return &instances.Items[0], nil
Expand Down

0 comments on commit ef9f7c7

Please sign in to comment.