diff --git a/updater/aws.go b/updater/aws.go index 87c4909..f1c3081 100644 --- a/updater/aws.go +++ b/updater/aws.go @@ -59,9 +59,9 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { return false } -// Checks tasks to determine if task was started by a service and drains the instances if so. Instances running no tasks at all are also drained. -// Instances running non-service tasks will not be drained and are not supported by ECS Updater. -func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap map[*string]*string, cluster *string) ([]*string, error) { +// Checks tasks to determine if task was started by a service and drains the instances if so. Instances running no tasks at all are also drained. +// Instances running non-service tasks will not be drained and are not supported by ECS Updater. +func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap map[*string]*string, cluster *string) ([]*string, map[string]string, error) { // Ensures only instances ready for update have tasks listed updateMap := make(map[string]string) for key, value := range instanceIDmap { @@ -79,7 +79,7 @@ func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap }) if err != nil { log.Printf("failed to list tasks: %#v", err) - return nil, err + return nil, nil, err } log.Printf("Tasks running on %v: %s", containerARN, resp) @@ -108,7 +108,7 @@ func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap } else { log.Printf("No instances to check for tasks") } - return drainedInstances, nil + return drainedInstances, updateMap, nil } func checkForServiceStarter(ecsClient *ecs.ECS, taskArns *ecs.ListTasksOutput, cluster *string) bool { @@ -184,7 +184,7 @@ func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) ( return commandID, nil } -func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) { +func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, string, error) { updateCandidates := make([]string, 0) for _, v := range instanceIDs { resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{ @@ -192,9 +192,10 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str InstanceId: v, }) if err != nil { - return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err) + return nil, "", fmt.Errorf("failed to retreive command invocation output: %#v", err) } + type updateCheckResult struct { UpdateState string `json:"update_state"` } @@ -209,11 +210,38 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str switch result.UpdateState { case "Available": updateCandidates = append(updateCandidates, *v) + case "Idle": + return nil, "Update applied", nil } } if updateCandidates == nil { - return nil, errors.New("No instances to update") + return nil, "", errors.New("No instances to update") + } + return updateCandidates, "", nil +} + +func checkUpdateResult(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]*string, error) { + updatedInstances := make([]*string, 0) + for _, v := range instanceIDs { + resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + CommandId: aws.String(commandID), + InstanceId: v, + }) + if err != nil { + return nil, fmt.Errorf("failed to failed to retrieve command invocation output: %#v", err) + } + log.Printf("Response: %s", *resp.Status) + if *resp.Status == "Success" { + updatedInstances = append(updatedInstances, v) + } else { + log.Printf("Instance %#v failed to update", v) + } + } + if len(updatedInstances) != 0 { + return updatedInstances, nil + } else { + log.Printf("No instances were successfully updated") } - return updateCandidates, nil + return updatedInstances, nil } diff --git a/updater/main.go b/updater/main.go index 906f4f3..8609c2d 100644 --- a/updater/main.go +++ b/updater/main.go @@ -62,23 +62,54 @@ func _main() error { return err } - instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances) + instancesToUpdate, _, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances) if err != nil { return err } fmt.Println("Instances ready for update: ", instancesToUpdate) - drainedInstances, err := checkTasksandDrain(ecsClient, instancesToUpdate, instanceMap, flagCluster) + drainedInstances, updateMap, err := checkTasksandDrain(ecsClient, instancesToUpdate, instanceMap, flagCluster) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) return err } for _, v := range drainedInstances { - log.Printf("Instance %s drained\n", *v) + log.Printf("Instance %s drained", *v) } log.Printf("Total instances drained: %#v", len(drainedInstances)) + if len(drainedInstances) != 0 { + // Dedupes drainedInstances slice + deDupedInstances := sanitizeIDs(drainedInstances) + for _, v := range deDupedInstances { + log.Printf("ready to update: %#v", *v) + } + + updateCommandID, err := sendCommand(ssmClient, deDupedInstances, "apiclient update apply -r") + if err != nil { + return err + } + + _, err = checkUpdateResult(ssmClient, updateCommandID, deDupedInstances) + if err != nil { + return err + } + + updateStatus, err := sendCommand(ssmClient, deDupedInstances, "apiclient update check") + if err != nil { + return err + } + + _, _, err = checkCommandOutput(ssmClient, updateStatus, deDupedInstances) + if err != nil { + return err + } + + for _, updatedARN := range updateMap { + changeInstanceState(ecsClient, aws.String(updatedARN), flagCluster, "ACTIVE") + } + + } return nil } diff --git a/updater/utils.go b/updater/utils.go new file mode 100644 index 0000000..eb895ae --- /dev/null +++ b/updater/utils.go @@ -0,0 +1,17 @@ +package main + + +func sanitizeIDs(instanceIDs []*string) []*string { + // If the instanceID occurs the flag changes to true. + // All instanceIDs that has already occurred will be true. + flag := make(map[*string]bool) + var uniqueIDs []*string + for _, ID := range instanceIDs { + if flag[ID] == false { + flag[ID] = true + uniqueIDs = append(uniqueIDs, ID) + } + } + // unique names collected + return uniqueIDs +} \ No newline at end of file