Skip to content

Commit

Permalink
Add drain instance functionality
Browse files Browse the repository at this point in the history
Adds functionality to list running tasks in a contianer instance. Listed tasks
are filtered to ensure instances running non-service tasks are not drained and
that instances not running tasks are passed to draining operation in preperation
for performing updates.

[Issue: #8]
  • Loading branch information
Will Moore committed Apr 13, 2021
1 parent ef2553a commit 430af79
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 11 deletions.
117 changes: 108 additions & 9 deletions updater/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"log"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
Expand All @@ -16,7 +18,7 @@ const (

func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) {
resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &cluster,
Cluster: &cluster,
MaxResults: aws.Int64(pageSize),
})
if err != nil {
Expand All @@ -26,25 +28,25 @@ func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64)
return resp.ContainerInstanceArns, nil
}

func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) {
func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, map[*string]*string, error) {
resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &cluster, ContainerInstances: instances,
})
if err != nil {
return nil, fmt.Errorf("cannot describe container instances: %#v", err)
return nil, nil, fmt.Errorf("cannot describe container instances: %#v", err)
}

log.Printf("Container descriptions: %#v", resp)

ec2IDs := make([]*string, 0)
// Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected
mapEC2IDtoContainerARN := make(map[*string]*string)
// Check the DescribeInstances response for Bottlerocket nodes, add them to ec2ids if detected
for _, instance := range resp.ContainerInstances {
if containsAttribute(instance.Attributes, "bottlerocket.variant") {
ec2IDs = append(ec2IDs, instance.Ec2InstanceId)
mapEC2IDtoContainerARN[instance.Ec2InstanceId] = instance.ContainerInstanceArn
log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId)
}
}
return ec2IDs, nil
return ec2IDs, mapEC2IDtoContainerARN, nil
}

// checks if ECS Attributes struct contains a specified string
Expand All @@ -57,8 +59,105 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
return false
}

// Checks tasks to determine if task was started by a service and drains the instances if so. Instances running no tasks at all are also drained.
// Instances running non-service tasks will not be drained and are not supported by ECS Updater.
func checkTasksandDrain(ecsClient *ecs.ECS, instanceIDs []string, instanceIDmap map[*string]*string, cluster *string) ([]*string, error) {
// Ensures only instances ready for update have tasks listed
updateMap := make(map[string]string)
for key, value := range instanceIDmap {
if sliceContains(instanceIDs, *key) {
updateMap[*key] = *value
}
}

drainedInstances := make([]*string, 0)
if len(updateMap) != 0 {
for ec2ID, containerARN := range updateMap {
resp, err := ecsClient.ListTasks(&ecs.ListTasksInput{
Cluster: cluster,
ContainerInstance: aws.String(containerARN),
})
if err != nil {
log.Printf("failed to list tasks: %#v", err)
return nil, err
}

log.Printf("Tasks running on %v: %s", containerARN, resp)
readyToDrain := make(map[string]string)
if len(resp.TaskArns) == 0 {
err := changeInstanceState(ecsClient, aws.String(containerARN), cluster, "DRAINING")
if err != nil {
log.Printf("failed to drain. Instance %#v reactivating. Error: %#v", ec2ID, err)
changeInstanceState(ecsClient, aws.String(containerARN), flagCluster, "ACTIVE")
}
drainedInstances = append(drainedInstances, &ec2ID)
} else if checkForServiceStarter(ecsClient, resp, cluster) {
readyToDrain[ec2ID] = containerARN
} else {
delete(readyToDrain, ec2ID)
}
for instanceID, cARN := range readyToDrain {
err := changeInstanceState(ecsClient, aws.String(cARN), cluster, "DRAINING")
if err != nil {
log.Printf("failed to drain. Instance %#v reactivating. Error: %#v", instanceID, err)
changeInstanceState(ecsClient, aws.String(cARN), cluster, "ACTIVE")
}
drainedInstances = append(drainedInstances, &instanceID)
}
}
} else {
log.Printf("No instances to check for tasks")
}
return drainedInstances, nil
}

func checkForServiceStarter(ecsClient *ecs.ECS, taskArns *ecs.ListTasksOutput, cluster *string) bool {
for _, taskARN := range taskArns.TaskArns {
resp, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
Cluster: cluster,
Tasks: []*string{taskARN},
})
if err != nil {
log.Printf("Could not describe task: %#v", taskARN)
}
for _, listResult := range resp.Tasks {
startedBy := aws.StringValue(listResult.StartedBy)
if !strings.Contains(startedBy, "ecs-svc") {
return false
}
}
}
return true
}

func changeInstanceState(ecsClient *ecs.ECS, containerARN, cluster *string, desiredState string) error {
resp, err := ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{
Cluster: cluster,
ContainerInstances: []*string{containerARN},
Status: aws.String(desiredState),
})
if err != nil {
log.Printf("failed to update container instance state: %#v", err)
return err
}
if len(resp.Failures) == 0 {
log.Printf("Container instance state changed to %s", desiredState)
}
return nil
}

// checks if slice contains string
func sliceContains(s []string, str string) bool {
for _, v := range s {
if v == str {
return true
}
}
return false
}

func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %#v", instanceIDs)
log.Printf("Checking InstanceIDs: %#v", &instanceIDs)

resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
Expand Down Expand Up @@ -114,7 +213,7 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str
}

if updateCandidates == nil {
log.Printf("No instances to update")
return nil, errors.New("No instances to update")
}
return updateCandidates, nil
}
16 changes: 14 additions & 2 deletions updater/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ func _main() error {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(*flagRegion),
}))
ecsClient := ecs.New(sess)
ecsClient := ecs.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody))
ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody))

instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize)
if err != nil {
return err
}

bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances)
bottlerocketInstances, instanceMap, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances)
if err != nil {
return err
}
Expand All @@ -68,5 +68,17 @@ func _main() error {
}

fmt.Println("Instances ready for update: ", instancesToUpdate)

drainedInstances, err := checkTasksandDrain(ecsClient, instancesToUpdate, instanceMap, flagCluster)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
return err
}

for _, v := range drainedInstances {
log.Printf("Instance %s drained\n", *v)
}
log.Printf("Total instances drained: %#v", len(drainedInstances))

return nil
}

0 comments on commit 430af79

Please sign in to comment.