Skip to content

Commit

Permalink
Add drain instance functionality
Browse files Browse the repository at this point in the history
Adds functionality to list running tasks in a contianer instance. Listed tasks
are filtered to ensure instances running non-service tasks are not drained and
that instances not running tasks are passed to draining operation in preperation
for performing updates.

[Issue: #8]
  • Loading branch information
Will Moore committed Apr 15, 2021
1 parent ef2553a commit e779565
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 34 deletions.
152 changes: 126 additions & 26 deletions updater/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"log"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
Expand All @@ -14,9 +16,9 @@ const (
pageSize = 50
)

func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) {
resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &cluster,
func (up *updater) listContainerInstances() ([]*string, error) {
resp, err := up.ecs.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &up.cluster,
MaxResults: aws.Int64(pageSize),
})
if err != nil {
Expand All @@ -26,28 +28,29 @@ func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64)
return resp.ContainerInstanceArns, nil
}

func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) {
resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &cluster, ContainerInstances: instances,
// filterBottlerocketInstances returns a map of EC2 instance IDs to container instance ARNs
// provided as input where the container instance is a Bottlerocket host.
func (u *updater) filterBottlerocketInstances(instances []*string) (map[string]string, error) {
resp, err := u.ecs.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &u.cluster,
ContainerInstances: instances,
})
if err != nil {
return nil, fmt.Errorf("cannot describe container instances: %#v", err)
}

log.Printf("Container descriptions: %#v", resp)

ec2IDs := make([]*string, 0)
// Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected
ec2IDtoECSARN := make(map[string]string)
// Check the DescribeInstances response for Bottlerocket nodes, add them to map if detected.
for _, instance := range resp.ContainerInstances {
if containsAttribute(instance.Attributes, "bottlerocket.variant") {
ec2IDs = append(ec2IDs, instance.Ec2InstanceId)
ec2IDtoECSARN[*instance.Ec2InstanceId] = *instance.ContainerInstanceArn
log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId)
}
}
return ec2IDs, nil
return ec2IDtoECSARN, nil
}

// checks if ECS Attributes struct contains a specified string
// containsAttribute checks if an ECS Attributes struct contains a specified string.
func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
for _, attr := range attrs {
if aws.StringValue(attr.Name) == searchString {
Expand All @@ -57,13 +60,110 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
return false
}

func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %#v", instanceIDs)
// checkAndDrain drains eligible container instances. Container instances are eligible if all running
// tasks were started by a service, or if there are no running tasks.
func (u *updater) checkAndDrain(containerInstance string) error {
if !eligible(u.ecs, &containerInstance, &u.cluster) {
return nil
}

err := changeInstanceState(u.ecs, aws.String(containerInstance), &u.cluster, "DRAINING")
if err != nil {
err2 := changeInstanceState(u.ecs, aws.String(containerInstance), &u.cluster, "ACTIVE")
if err2 != nil {
return fmt.Errorf("failed to undrain: %w", err2.Error())
}
return fmt.Errorf("failed to drain: %w", err)
}
return nil
}

func eligible(ecsClient *ecs.ECS, containerInstance, cluster *string) bool {
list, err := ecsClient.ListTasks(&ecs.ListTasksInput{
Cluster: cluster,
ContainerInstance: containerInstance,
})
if err != nil {
return false
}

taskARNs := list.TaskArns
if len(list.TaskArns) == 0 {
return true
}

desc, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
Cluster: cluster,
Tasks: taskARNs,
})
if err != nil {
log.Printf("Could not describe tasks")
return false
}

for _, listResult := range desc.Tasks {
startedBy := aws.StringValue(listResult.StartedBy)
if !strings.HasPrefix(startedBy, "ecs-svc/") {
return false
}
}
return true
}

func changeInstanceState(ecsClient *ecs.ECS, containerInstance, cluster *string, desiredState string) error {
resp, err := ecsClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{
Cluster: cluster,
ContainerInstances: []*string{containerInstance},
Status: aws.String(desiredState),
})
if err != nil {
log.Printf("failed to update container instance state: %#v", err)
return err
}
if desiredState == "DRAINING" {
err := waitUntilDrained(ecsClient, *containerInstance, cluster)
if err != nil {
return err
}
}
if len(resp.Failures) == 0 {
log.Printf("Container instance state changed to %s", desiredState)
}
return nil
}

func waitUntilDrained(ecsClient *ecs.ECS, containerInstance string, cluster *string) error {
list, err := ecsClient.ListTasks(&ecs.ListTasksInput{
Cluster: cluster,
ContainerInstance: aws.String(containerInstance),
})
if err != nil {
log.Printf("failed to identify a task to wait on")
return err
}

taskARNs := list.TaskArns

if len(taskARNs) == 0 {
return nil
}
err2 := ecsClient.WaitUntilTasksStopped(&ecs.DescribeTasksInput{
Cluster: cluster,
Tasks: taskARNs,
})
if err2 != nil {
return err
}
return nil
}

func (u *updater) sendCommand(instanceIDs []string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %q", instanceIDs)

resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{
resp, err := u.ssm.SendCommand(&ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
DocumentVersion: aws.String("$DEFAULT"),
InstanceIds: instanceIDs,
InstanceIds: aws.StringSlice(instanceIDs),
Parameters: map[string][]*string{
"commands": {aws.String(ssmCommand)},
},
Expand All @@ -73,24 +173,24 @@ func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (
}

commandID := *resp.Command.CommandId
// Wait for the sent commands to complete
// Wait for the sent commands to complete.
// TODO Update this to use WaitGroups
for _, v := range instanceIDs {
ssmClient.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{
u.ssm.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{
CommandId: &commandID,
InstanceId: v,
InstanceId: &v,
})
}
log.Printf("CommandID: %#v", commandID)
log.Printf("CommandID: %s", commandID)
return commandID, nil
}

func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) {
func (u *updater) checkCommandOutput(commandID string, instanceIDs []string) ([]string, error) {
updateCandidates := make([]string, 0)
for _, v := range instanceIDs {
resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{
resp, err := u.ssm.GetCommandInvocation(&ssm.GetCommandInvocationInput{
CommandId: aws.String(commandID),
InstanceId: v,
InstanceId: aws.String(v),
})
if err != nil {
return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err)
Expand All @@ -109,12 +209,12 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str

switch result.UpdateState {
case "Available":
updateCandidates = append(updateCandidates, *v)
updateCandidates = append(updateCandidates, v)
}
}

if updateCandidates == nil {
log.Printf("No instances to update")
return nil, errors.New("No instances to update")
}
return updateCandidates, nil
}
44 changes: 36 additions & 8 deletions updater/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ var (
flagRegion = flag.String("region", "", "The AWS Region in which cluster is running.")
)

type updater struct {
cluster string
ecs *ecs.ECS
ssm *ssm.SSM
}

func main() {
if err := _main(); err != nil {
fmt.Fprintf(os.Stderr, err.Error())
log.Println(err.Error())
os.Exit(1)
}
}
Expand All @@ -39,34 +45,56 @@ func _main() error {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(*flagRegion),
}))
ecsClient := ecs.New(sess)
ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody))

instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize)
up := &updater{
cluster: *flagCluster,
ecs: ecs.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)),
ssm: ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)),
}

listedInstances, err := up.listContainerInstances()
if err != nil {
return err
}

bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances)
bottlerocketInstanceMap, err := up.filterBottlerocketInstances(listedInstances)
if err != nil {
return err
}

if len(bottlerocketInstances) == 0 {
if len(bottlerocketInstanceMap) == 0 {
log.Printf("No Bottlerocket instances detected")
return nil
}

commandID, err := sendCommand(ssmClient, bottlerocketInstances, "apiclient update check")
// Make slice of Bottlerocket instances to use with SendCommand and checkCommandOutput
instances := make([]string, 0)
for instance, _ := range bottlerocketInstanceMap {
instances = append(instances, instance)
}

commandID, err := up.sendCommand(instances, "apiclient update check")
if err != nil {
return err
}

instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances)
instancesToUpdate, err := up.checkCommandOutput(commandID, instances)
if err != nil {
return err
}

if len(instancesToUpdate) == 0 {
log.Printf("No Instances to update")
return nil
}
fmt.Println("Instances ready for update: ", instancesToUpdate)

for ec2ID, containerARN := range bottlerocketInstanceMap {
err := up.checkAndDrain(containerARN)
if err != nil {
return err
}
log.Printf("Instance %s drained", ec2ID)
}
return nil
}

0 comments on commit e779565

Please sign in to comment.