diff --git a/updater/aws.go b/updater/aws.go new file mode 100644 index 0000000..39c5f87 --- /dev/null +++ b/updater/aws.go @@ -0,0 +1,120 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go/service/ssm" +) + +const ( + pageSize = 50 +) + +func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) { + resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{ + Cluster: &cluster, + MaxResults: aws.Int64(pageSize), + }) + if err != nil { + return nil, fmt.Errorf("cannot list container instances: %#v", err) + } + log.Printf("%#v", resp) + return resp.ContainerInstanceArns, nil +} + +func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) { + resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ + Cluster: &cluster, ContainerInstances: instances, + }) + if err != nil { + return nil, fmt.Errorf("cannot describe container instances: %#v", err) + } + + log.Printf("Container descriptions: %#v", resp) + + ec2IDs := make([]*string, 0) + // Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected + for _, instance := range resp.ContainerInstances { + if containsAttribute(instance.Attributes, "bottlerocket.variant") { + ec2IDs = append(ec2IDs, instance.Ec2InstanceId) + log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId) + } + } + return ec2IDs, nil +} + +// checks if ECS Attributes struct contains a specified string +func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { + for _, attr := range attrs { + if aws.StringValue(attr.Name) == searchString { + return true + } + } + return false +} + +func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) { + log.Printf("Checking InstanceIDs: %#v", instanceIDs) + + resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{ + DocumentName: aws.String("AWS-RunShellScript"), + DocumentVersion: aws.String("$DEFAULT"), + InstanceIds: instanceIDs, + Parameters: map[string][]*string{ + "commands": {aws.String(ssmCommand)}, + }, + }) + if err != nil { + return "", fmt.Errorf("command invocation failed: %#v", err) + } + + commandID := *resp.Command.CommandId + // Wait for the sent commands to complete + // TODO Update this to use WaitGroups + for _, v := range instanceIDs { + ssmClient.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{ + CommandId: &commandID, + InstanceId: v, + }) + } + log.Printf("CommandID: %#v", commandID) + return commandID, nil +} + +func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) { + updateCandidates := make([]string, 0) + for _, v := range instanceIDs { + resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + CommandId: aws.String(commandID), + InstanceId: v, + }) + if err != nil { + return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err) + } + + type updateCheckResult struct { + UpdateState string `json:"update_state"` + } + + var result updateCheckResult + err = json.Unmarshal([]byte(*resp.StandardOutputContent), &result) + if err != nil { + log.Printf("failed to unmarshal command invocation output: %#v", err) + } + log.Println("update_state: ", result) + + switch result.UpdateState { + case "Available": + updateCandidates = append(updateCandidates, *v) + } + } + + if updateCandidates == nil { + log.Printf("No instances to update") + } + return updateCandidates, nil +} diff --git a/updater/main.go b/updater/main.go index 08a6bfc..c9bc82b 100644 --- a/updater/main.go +++ b/updater/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "log" @@ -9,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go/service/ssm" ) var ( @@ -17,81 +19,54 @@ var ( ) func main() { + if err := _main(); err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } +} + +func _main() error { flag.Parse() switch { case *flagCluster == "": - log.Println("cluster is required") flag.Usage() - os.Exit(1) + return errors.New("cluster is required") case *flagRegion == "": - log.Println("region is required") flag.Usage() - os.Exit(1) + return errors.New("region is required") } sess := session.Must(session.NewSession(&aws.Config{ Region: aws.String(*flagRegion), })) ecsClient := ecs.New(sess) + ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)) - instances, err := listContainerInstances(ecsClient, *flagCluster) + instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) + return err } bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) + return err } if len(bottlerocketInstances) == 0 { log.Printf("No Bottlerocket instances detected") + return nil } -} -func listContainerInstances(ecsClient *ecs.ECS, cluster string) ([]*string, error) { - resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{Cluster: &cluster}) + commandID, err := sendCommand(ssmClient, bottlerocketInstances, "apiclient update check") if err != nil { - return nil, fmt.Errorf("Cannot list container instances: %#v", err) + return err } - log.Printf("%#v", resp) - var values []*string - - for _, v := range resp.ContainerInstanceArns { - values = append(values, v) - } - return values, nil -} -func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]string, error) { - resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ - Cluster: &cluster, ContainerInstances: instances, - }) + instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances) if err != nil { - return nil, fmt.Errorf("Cannot describe container instances: %#v", err) + return err } - log.Printf("Container descriptions: %#v", resp) - var ec2IDs []string - - //Check the DescribeInstances response for Bottlerocket nodes, add them to ec2ids if detected - for _, instance := range resp.ContainerInstances { - if containsAttribute(instance.Attributes, "bottlerocket.variant") { - ec2IDs = append(ec2IDs, *instance.Ec2InstanceId) - log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId) - } - } - return ec2IDs, nil -} - -// checks if ECS Attributes struct contains a specified string -func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { - for _, attr := range attrs { - if aws.StringValue(attr.Name) == searchString { - return true - } - } - return false + fmt.Println("Instances ready for update: ", instancesToUpdate) + return nil }