From 430af79b2834f852310d218cad7a7475e1f4f645 Mon Sep 17 00:00:00 2001 From: Will Moore Date: Fri, 9 Apr 2021 08:58:49 -0700 Subject: [PATCH] Add drain instance functionality Adds functionality to list running tasks in a contianer instance. Listed tasks are filtered to ensure instances running non-service tasks are not drained and that instances not running tasks are passed to draining operation in preperation for performing updates. [Issue: https://github.com/bottlerocket-os/bottlerocket-ecs-updater/issues/8] --- updater/aws.go | 117 ++++++++++++++++++++++++++++++++++++++++++++---- updater/main.go | 16 ++++++- 2 files changed, 122 insertions(+), 11 deletions(-) diff --git a/updater/aws.go b/updater/aws.go index 39c5f87..87c4909 100644 --- a/updater/aws.go +++ b/updater/aws.go @@ -2,8 +2,10 @@ package main import ( "encoding/json" + "errors" "fmt" "log" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" @@ -16,7 +18,7 @@ const ( func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) { resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{ - Cluster: &cluster, + Cluster: &cluster, MaxResults: aws.Int64(pageSize), }) if err != nil { @@ -26,25 +28,25 @@ func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) return resp.ContainerInstanceArns, nil } -func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) { +func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, map[*string]*string, error) { resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ Cluster: &cluster, ContainerInstances: instances, }) if err != nil { - return nil, fmt.Errorf("cannot describe container instances: %#v", err) + return nil, nil, fmt.Errorf("cannot describe container instances: %#v", err) } - log.Printf("Container descriptions: %#v", resp) - ec2IDs := make([]*string, 0) - // Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected + mapEC2IDtoContainerARN := make(map[*string]*string) + // Check the DescribeInstances response for Bottlerocket nodes, add them to ec2ids if detected for _, instance := range resp.ContainerInstances { if containsAttribute(instance.Attributes, "bottlerocket.variant") { ec2IDs = append(ec2IDs, instance.Ec2InstanceId) + mapEC2IDtoContainerARN[instance.Ec2InstanceId] = instance.ContainerInstanceArn log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId) } } - return ec2IDs, nil + return ec2IDs, mapEC2IDtoContainerARN, nil } // checks if ECS Attributes struct contains a specified string @@ -57,8 +59,105 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool { return false } +// Checks tasks to determine if task was started by a service and drains the instances if so. Instances running no tasks at all are also drained. +// Instances running non-service tasks will not be drained and are not supported by ECS Updater. +func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap map[*string]*string, cluster *string) ([]*string, error) { + // Ensures only instances ready for update have tasks listed + updateMap := make(map[string]string) + for key, value := range instanceIDmap { + if sliceContains(instanceIDs, *key) { + updateMap[*key] = *value + } + } + + drainedInstances := make([]*string, 0) + if len(updateMap) != 0 { + for ec2ID, containerARN := range updateMap { + resp, err := ecsClient.ListTasks(&ecs.ListTasksInput{ + Cluster: cluster, + ContainerInstance: aws.String(containerARN), + }) + if err != nil { + log.Printf("failed to list tasks: %#v", err) + return nil, err + } + + log.Printf("Tasks running on %v: %s", containerARN, resp) + readyToDrain := make(map[string]string) + if len(resp.TaskArns) == 0 { + err := changeInstanceState(ecsClient, aws.String(containerARN), cluster, "DRAINING") + if err != nil { + log.Printf("failed to drain. Instance %#v reactivating. Error: %#v", ec2ID, err) + changeInstanceState(ecsClient, aws.String(containerARN), flagCluster, "ACTIVE") + } + drainedInstances = append(drainedInstances, &ec2ID) + } else if checkForServiceStarter(ecsClient, resp, cluster) { + readyToDrain[ec2ID] = containerARN + } else { + delete(readyToDrain, ec2ID) + } + for instanceID, cARN := range readyToDrain { + err := changeInstanceState(ecsClient, aws.String(cARN), cluster, "DRAINING") + if err != nil { + log.Printf("failed to drain. Instance %#v reactivating. Error: %#v", instanceID, err) + changeInstanceState(ecsClient, aws.String(cARN), cluster, "ACTIVE") + } + drainedInstances = append(drainedInstances, &instanceID) + } + } + } else { + log.Printf("No instances to check for tasks") + } + return drainedInstances, nil +} + +func checkForServiceStarter(ecsClient *ecs.ECS, taskArns *ecs.ListTasksOutput, cluster *string) bool { + for _, taskARN := range taskArns.TaskArns { + resp, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{ + Cluster: cluster, + Tasks: []*string{taskARN}, + }) + if err != nil { + log.Printf("Could not describe task: %#v", taskARN) + } + for _, listResult := range resp.Tasks { + startedBy := aws.StringValue(listResult.StartedBy) + if !strings.Contains(startedBy, "ecs-svc") { + return false + } + } + } + return true +} + +func changeInstanceState(ecsClient *ecs.ECS, containerARN, cluster *string, desiredState string) error { + resp, err := ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ + Cluster: cluster, + ContainerInstances: []*string{containerARN}, + Status: aws.String(desiredState), + }) + if err != nil { + log.Printf("failed to update container instance state: %#v", err) + return err + } + if len(resp.Failures) == 0 { + log.Printf("Container instance state changed to %s", desiredState) + } + return nil +} + +// checks if slice contains string +func sliceContains(s []string, str string) bool { + for _, v := range s { + if v == str { + return true + } + } + return false +} + func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) { - log.Printf("Checking InstanceIDs: %#v", instanceIDs) + log.Printf("Checking InstanceIDs: %#v", &instanceIDs) resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{ DocumentName: aws.String("AWS-RunShellScript"), @@ -114,7 +213,7 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str } if updateCandidates == nil { - log.Printf("No instances to update") + return nil, errors.New("No instances to update") } return updateCandidates, nil } diff --git a/updater/main.go b/updater/main.go index c9bc82b..906f4f3 100644 --- a/updater/main.go +++ b/updater/main.go @@ -39,7 +39,7 @@ func _main() error { sess := session.Must(session.NewSession(&aws.Config{ Region: aws.String(*flagRegion), })) - ecsClient := ecs.New(sess) + ecsClient := ecs.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)) ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)) instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize) @@ -47,7 +47,7 @@ func _main() error { return err } - bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances) + bottlerocketInstances, instanceMap, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances) if err != nil { return err } @@ -68,5 +68,17 @@ func _main() error { } fmt.Println("Instances ready for update: ", instancesToUpdate) + + drainedInstances, err := checkTasksandDrain(ecsClient, instancesToUpdate, instanceMap, flagCluster) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + return err + } + + for _, v := range drainedInstances { + log.Printf("Instance %s drained\n", *v) + } + log.Printf("Total instances drained: %#v", len(drainedInstances)) + return nil }