From c13d549436fbe28740e126fafa174a5a1f361c3f Mon Sep 17 00:00:00 2001 From: Will Moore Date: Fri, 9 Apr 2021 08:58:49 -0700 Subject: [PATCH] Add drain instance functionality Adds functionality to list running tasks in a contianer instance. Listed tasks are filtered to ensure instances running non-service tasks are not drained and that instances not running tasks are passed to draining operation in preperation for performing updates. [Issue: https://github.com/bottlerocket-os/bottlerocket-ecs-updater/issues/8] --- updater/aws.go | 168 +++++++++++++++++++++++++++++++++++++++++------- updater/main.go | 44 ++++++++++--- 2 files changed, 179 insertions(+), 33 deletions(-) diff --git a/updater/aws.go b/updater/aws.go index 39c5f87..0357a60 100644 --- a/updater/aws.go +++ b/updater/aws.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "log" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" @@ -14,9 +15,9 @@ const ( pageSize = 50 ) -func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) { - resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{ - Cluster: &cluster, +func (up *updater) listContainerInstances() ([]*string, error) { + resp, err := up.ecs.ListContainerInstances(&ecs.ListContainerInstancesInput{ + Cluster: &up.cluster, MaxResults: aws.Int64(pageSize), }) if err != nil { @@ -26,28 +27,29 @@ func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) return resp.ContainerInstanceArns, nil } -func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) { - resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ - Cluster: &cluster, ContainerInstances: instances, +// filterBottlerocketInstances returns a map of EC2 instance IDs to container instance ARNs +// provided as input where the container instance is a Bottlerocket host. +func (u *updater) filterBottlerocketInstances(instances []*string) (map[string]string, error) { + resp, err := u.ecs.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ + Cluster: &u.cluster, + ContainerInstances: instances, }) if err != nil { return nil, fmt.Errorf("cannot describe container instances: %#v", err) } - log.Printf("Container descriptions: %#v", resp) - - ec2IDs := make([]*string, 0) - // Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected + ec2IDtoECSARN := make(map[string]string) + // Check the DescribeInstances response for Bottlerocket nodes, add them to map if detected. for _, instance := range resp.ContainerInstances { if containsAttribute(instance.Attributes, "bottlerocket.variant") { - ec2IDs = append(ec2IDs, instance.Ec2InstanceId) + ec2IDtoECSARN[*instance.Ec2InstanceId] = *instance.ContainerInstanceArn log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId) } } - return ec2IDs, nil + return ec2IDtoECSARN, nil } -// checks if ECS Attributes struct contains a specified string +// containsAttribute checks if an ECS Attributes struct contains a specified string. func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { for _, attr := range attrs { if aws.StringValue(attr.Name) == searchString { @@ -57,13 +59,128 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { return false } -func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) { - log.Printf("Checking InstanceIDs: %#v", instanceIDs) +// checkAndDrain drains eligible container instances. Container instances are eligible if all running +// tasks were started by a service, or if there are no running tasks. +func (u *updater) checkAndDrain(containerInstance string) error { + if !eligible(u.ecs, &containerInstance, &u.cluster) { + return nil + } + err := drainInstance(u.ecs, aws.String(containerInstance), &u.cluster) + if err != nil { + return err + } + return nil +} + +func eligible(ecsClient *ecs.ECS, containerInstance, cluster *string) bool { + list, err := ecsClient.ListTasks(&ecs.ListTasksInput{ + Cluster: cluster, + ContainerInstance: containerInstance, + }) + if err != nil { + return false + } + + taskARNs := list.TaskArns + if len(list.TaskArns) == 0 { + return true + } + + desc, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{ + Cluster: cluster, + Tasks: taskARNs, + }) + if err != nil { + log.Printf("Could not describe tasks") + return false + } + + for _, listResult := range desc.Tasks { + startedBy := aws.StringValue(listResult.StartedBy) + if !strings.HasPrefix(startedBy, "ecs-svc/") { + return false + } + } + return true +} + +func drainInstance(ecsClient *ecs.ECS, containerInstance, cluster *string) error { + resp, err := ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ + Cluster: cluster, + ContainerInstances: []*string{containerInstance}, + Status: aws.String("DRAINING"), + }) + if err != nil { + log.Printf("failed to update container instance %s state to DRAINING: %#v", containerInstance, err) + return err + } + if len(resp.Failures) == 0 { + log.Printf("Container instance state changed to DRAINING") + } else { + log.Printf("Container instance %s failed to drain: %#v", containerInstance, resp.Failures) + // TODO Determine if instance should be reactivated here + return nil + } + + err = waitUntilDrained(ecsClient, *containerInstance, cluster) + if err != nil { + activateInstance(ecsClient, containerInstance, cluster) + return err + } + return nil +} + +func activateInstance(ecsClient *ecs.ECS, containerInstance, cluster *string) error { + resp, err := ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ + Cluster: cluster, + ContainerInstances: []*string{containerInstance}, + Status: aws.String("ACTIVE"), + }) + if err != nil { + log.Printf("failed to update container %s instance state to ACTIVE: %#v", containerInstance, err) + return err + } + if len(resp.Failures) == 0 { + log.Printf("Container instance state changed to ACTIVE") + } else { + log.Printf("Container instance %s failed to activate: %#v", containerInstance, resp.Failures) + } + return nil +} + +func waitUntilDrained(ecsClient *ecs.ECS, containerInstance string, cluster *string) error { + list, err := ecsClient.ListTasks(&ecs.ListTasksInput{ + Cluster: cluster, + ContainerInstance: aws.String(containerInstance), + }) + if err != nil { + log.Printf("failed to identify a task to wait on") + return err + } + + taskARNs := list.TaskArns + + if len(taskARNs) == 0 { + return nil + } + // TODO Tune MaxAttempts + err = ecsClient.WaitUntilTasksStopped(&ecs.DescribeTasksInput{ + Cluster: cluster, + Tasks: taskARNs, + }) + if err != nil { + return err + } + return nil +} + +func (u *updater) sendCommand(instanceIDs []string, ssmCommand string) (string, error) { + log.Printf("Checking InstanceIDs: %q", instanceIDs) - resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{ + resp, err := u.ssm.SendCommand(&ssm.SendCommandInput{ DocumentName: aws.String("AWS-RunShellScript"), DocumentVersion: aws.String("$DEFAULT"), - InstanceIds: instanceIDs, + InstanceIds: aws.StringSlice(instanceIDs), Parameters: map[string][]*string{ "commands": {aws.String(ssmCommand)}, }, @@ -73,24 +190,24 @@ func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) ( } commandID := *resp.Command.CommandId - // Wait for the sent commands to complete + // Wait for the sent commands to complete. // TODO Update this to use WaitGroups for _, v := range instanceIDs { - ssmClient.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{ + u.ssm.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{ CommandId: &commandID, - InstanceId: v, + InstanceId: &v, }) } - log.Printf("CommandID: %#v", commandID) + log.Printf("CommandID: %s", commandID) return commandID, nil } -func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) { +func (u *updater) checkSSMCommandOutput(commandID string, instanceIDs []string) ([]string, error) { updateCandidates := make([]string, 0) for _, v := range instanceIDs { - resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{ + resp, err := u.ssm.GetCommandInvocation(&ssm.GetCommandInvocationInput{ CommandId: aws.String(commandID), - InstanceId: v, + InstanceId: aws.String(v), }) if err != nil { return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err) @@ -109,12 +226,13 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str switch result.UpdateState { case "Available": - updateCandidates = append(updateCandidates, *v) + updateCandidates = append(updateCandidates, v) } } if updateCandidates == nil { log.Printf("No instances to update") + return nil, nil } return updateCandidates, nil } diff --git a/updater/main.go b/updater/main.go index c9bc82b..8a5b58c 100644 --- a/updater/main.go +++ b/updater/main.go @@ -18,9 +18,15 @@ var ( flagRegion = flag.String("region", "", "The AWS Region in which cluster is running.") ) +type updater struct { + cluster string + ecs *ecs.ECS + ssm *ssm.SSM +} + func main() { if err := _main(); err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + log.Println(err.Error()) os.Exit(1) } } @@ -39,34 +45,56 @@ func _main() error { sess := session.Must(session.NewSession(&aws.Config{ Region: aws.String(*flagRegion), })) - ecsClient := ecs.New(sess) - ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)) - instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize) + up := &updater{ + cluster: *flagCluster, + ecs: ecs.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)), + ssm: ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)), + } + + listedInstances, err := up.listContainerInstances() if err != nil { return err } - bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances) + bottlerocketInstanceMap, err := up.filterBottlerocketInstances(listedInstances) if err != nil { return err } - if len(bottlerocketInstances) == 0 { + if len(bottlerocketInstanceMap) == 0 { log.Printf("No Bottlerocket instances detected") return nil } - commandID, err := sendCommand(ssmClient, bottlerocketInstances, "apiclient update check") + // Make slice of Bottlerocket instances to use with SendCommand and checkCommandOutput + instances := make([]string, 0) + for instance, _ := range bottlerocketInstanceMap { + instances = append(instances, instance) + } + + commandID, err := up.sendCommand(instances, "apiclient update check") if err != nil { return err } - instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances) + instancesToUpdate, err := up.checkSSMCommandOutput(commandID, instances) if err != nil { return err } + if len(instancesToUpdate) == 0 { + log.Printf("No Instances to update") + return nil + } fmt.Println("Instances ready for update: ", instancesToUpdate) + + for ec2ID, containerARN := range bottlerocketInstanceMap { + err := up.checkAndDrain(containerARN) + if err != nil { + log.Printf("%#v", err) + } + log.Printf("Instance %s drained", ec2ID) + } return nil }