Skip to content

Commit

Permalink
Merge pull request #35 from bottlerocket-os/DrainInstances
Browse files Browse the repository at this point in the history
Drain instances
  • Loading branch information
WilboMo authored Apr 23, 2021
2 parents 58bcd2b + c4d0474 commit bcb7d3e
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 36 deletions.
170 changes: 143 additions & 27 deletions updater/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"log"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
Expand All @@ -14,9 +16,9 @@ const (
pageSize = 50
)

func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) {
resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &cluster,
func (u *updater) listContainerInstances() ([]*string, error) {
resp, err := u.ecs.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &u.cluster,
MaxResults: aws.Int64(pageSize),
})
if err != nil {
Expand All @@ -26,28 +28,29 @@ func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64)
return resp.ContainerInstanceArns, nil
}

func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) {
resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &cluster, ContainerInstances: instances,
// filterBottlerocketInstances returns a map of EC2 instance IDs to container instance ARNs
// provided as input where the container instance is a Bottlerocket host.
func (u *updater) filterBottlerocketInstances(instances []*string) (map[string]string, error) {
resp, err := u.ecs.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &u.cluster,
ContainerInstances: instances,
})
if err != nil {
return nil, fmt.Errorf("cannot describe container instances: %#v", err)
}

log.Printf("Container descriptions: %#v", resp)

ec2IDs := make([]*string, 0)
// Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected
ec2IDtoECSARN := make(map[string]string)
// Check the DescribeContainerInstances response for Bottlerocket nodes, add them to map if detected.
for _, instance := range resp.ContainerInstances {
if containsAttribute(instance.Attributes, "bottlerocket.variant") {
ec2IDs = append(ec2IDs, instance.Ec2InstanceId)
log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId)
ec2IDtoECSARN[aws.StringValue(instance.Ec2InstanceId)] = aws.StringValue(instance.ContainerInstanceArn)
log.Printf("Bottlerocket instance detected. Instance %s added to check updates", aws.StringValue(instance.Ec2InstanceId))
}
}
return ec2IDs, nil
return ec2IDtoECSARN, nil
}

// checks if ECS Attributes struct contains a specified string
// containsAttribute checks if a slice of ECS Attributes struct contains a specified name.
func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
for _, attr := range attrs {
if aws.StringValue(attr.Name) == searchString {
Expand All @@ -57,13 +60,125 @@ func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
return false
}

func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %#v", instanceIDs)
// drain moves a container instance into the DRAINING state after checking
// that it is eligible. Eligibility is decided by u.eligible; when that check
// reports false, no state change is requested and an error is returned.
// Otherwise the result of drainInstance is returned unchanged.
func (u *updater) drain(containerInstance string) error {
	arn := aws.String(containerInstance)
	if u.eligible(arn) {
		return u.drainInstance(arn)
	}
	return errors.New("ineligible for updates")
}

// eligible reports whether the given container instance may be drained.
//
// An instance is eligible when it has no listed tasks, or when every listed
// task has a StartedBy value beginning with "ecs-svc/" (i.e. it was started
// by an ECS service). Any failure to list or describe tasks is logged and
// treated as "not eligible", so the caller never drains an instance whose
// tasks could not be verified.
//
// Every page of ListTasks results is inspected; each page is described with
// its own DescribeTasks call so a single request never carries more task ARNs
// than one ListTasks page returned.
func (u *updater) eligible(containerInstance *string) bool {
	instance := aws.StringValue(containerInstance)

	// ok is flipped to false by the page callback as soon as a page cannot be
	// verified or contains a task that was not started by a service.
	ok := true
	err := u.ecs.ListTasksPages(&ecs.ListTasksInput{
		Cluster:           &u.cluster,
		ContainerInstance: containerInstance,
	}, func(page *ecs.ListTasksOutput, _ bool) bool {
		if len(page.TaskArns) == 0 {
			return true
		}

		desc, err := u.ecs.DescribeTasks(&ecs.DescribeTasksInput{
			Cluster: &u.cluster,
			Tasks:   page.TaskArns,
		})
		if err != nil {
			log.Printf("failed to describe tasks for container instance %s: %#v", instance, err)
			ok = false
			return false
		}
		// A task that could not be described cannot be shown to belong to a
		// service, so err on the side of not draining.
		if len(desc.Failures) != 0 {
			log.Printf("failed to describe some tasks for container instance %s: %#v", instance, desc.Failures)
			ok = false
			return false
		}

		for _, task := range desc.Tasks {
			if !strings.HasPrefix(aws.StringValue(task.StartedBy), "ecs-svc/") {
				ok = false
				return false
			}
		}
		return true
	})
	if err != nil {
		log.Printf("failed to list tasks for container instance %s: %#v", instance, err)
		return false
	}
	return ok
}

// drainInstance requests the DRAINING state for a single container instance
// and then blocks in waitUntilDrained.
//
// If the state change is reported as failed, or waiting for the drain returns
// an error, the instance is switched back to ACTIVE on a best-effort basis: a
// reactivation error is only logged, and the original drain error is what the
// caller receives.
func (u *updater) drainInstance(containerInstance *string) error {
	instanceARN := aws.StringValue(containerInstance)

	out, err := u.ecs.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{
		Cluster:            &u.cluster,
		ContainerInstances: []*string{containerInstance},
		Status:             aws.String("DRAINING"),
	})
	if err != nil {
		log.Printf("failed to update container instance %s state to DRAINING: %#v", instanceARN, err)
		return err
	}

	if len(out.Failures) > 0 {
		// The API call succeeded but the instance was rejected; undo any partial change.
		if activateErr := u.activateInstance(containerInstance); activateErr != nil {
			log.Printf("instance failed to reactivate after failing to drain: %#v", activateErr)
		}
		return fmt.Errorf("Container instance %s failed to drain: %#v", instanceARN, out.Failures)
	}
	log.Printf("Container instance state changed to DRAINING")

	waitErr := u.waitUntilDrained(instanceARN)
	if waitErr == nil {
		return nil
	}

	// Draining did not complete; put the instance back into service.
	if activateErr := u.activateInstance(containerInstance); activateErr != nil {
		log.Printf("failed to reactivate %s: %s", instanceARN, activateErr.Error())
	}
	return waitErr
}

// activateInstance requests the ACTIVE state for a single container instance.
// It returns the API error if the call fails, or an error carrying the
// reported failures if the response contains any; otherwise it logs the state
// change and returns nil.
func (u *updater) activateInstance(containerInstance *string) error {
	input := &ecs.UpdateContainerInstancesStateInput{
		Cluster:            &u.cluster,
		ContainerInstances: []*string{containerInstance},
		Status:             aws.String("ACTIVE"),
	}

	out, err := u.ecs.UpdateContainerInstancesState(input)
	switch {
	case err != nil:
		log.Printf("failed to update container %s instance state to ACTIVE: %#v", aws.StringValue(containerInstance), err)
		return err
	case len(out.Failures) > 0:
		return fmt.Errorf("Container instance %s failed to activate: %#v", aws.StringValue(containerInstance), out.Failures)
	}

	log.Printf("Container instance state changed to ACTIVE")
	return nil
}

// waitUntilDrained blocks until the tasks listed for the given container
// instance have stopped, or returns the first error encountered while listing
// or waiting. It returns nil immediately when no tasks are listed.
//
// All pages of ListTasks results are collected before waiting, so tasks beyond
// the first page are not skipped. The waiter is then invoked on batches of at
// most maxTasksPerDescribe ARNs, because it takes a DescribeTasksInput.
func (u *updater) waitUntilDrained(containerInstance string) error {
	// maxTasksPerDescribe is the documented DescribeTasks limit on task
	// entries per request.
	const maxTasksPerDescribe = 100

	var taskARNs []*string
	err := u.ecs.ListTasksPages(&ecs.ListTasksInput{
		Cluster:           &u.cluster,
		ContainerInstance: aws.String(containerInstance),
	}, func(page *ecs.ListTasksOutput, _ bool) bool {
		taskARNs = append(taskARNs, page.TaskArns...)
		return true
	})
	if err != nil {
		log.Printf("failed to identify a task to wait on")
		return err
	}

	for len(taskARNs) > 0 {
		n := len(taskARNs)
		if n > maxTasksPerDescribe {
			n = maxTasksPerDescribe
		}
		// TODO Tune MaxAttempts
		err := u.ecs.WaitUntilTasksStopped(&ecs.DescribeTasksInput{
			Cluster: &u.cluster,
			Tasks:   taskARNs[:n],
		})
		if err != nil {
			return err
		}
		taskARNs = taskARNs[n:]
	}
	return nil
}

func (u *updater) sendCommand(instanceIDs []string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %q", instanceIDs)

resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{
resp, err := u.ssm.SendCommand(&ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
DocumentVersion: aws.String("$DEFAULT"),
InstanceIds: instanceIDs,
InstanceIds: aws.StringSlice(instanceIDs),
Parameters: map[string][]*string{
"commands": {aws.String(ssmCommand)},
},
Expand All @@ -73,24 +188,24 @@ func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (
}

commandID := *resp.Command.CommandId
// Wait for the sent commands to complete
// Wait for the sent commands to complete.
// TODO Update this to use WaitGroups
for _, v := range instanceIDs {
ssmClient.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{
u.ssm.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{
CommandId: &commandID,
InstanceId: v,
InstanceId: &v,
})
}
log.Printf("CommandID: %#v", commandID)
log.Printf("CommandID: %s", commandID)
return commandID, nil
}

func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) {
func (u *updater) checkSSMCommandOutput(commandID string, instanceIDs []string) ([]string, error) {
updateCandidates := make([]string, 0)
for _, v := range instanceIDs {
resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{
resp, err := u.ssm.GetCommandInvocation(&ssm.GetCommandInvocationInput{
CommandId: aws.String(commandID),
InstanceId: v,
InstanceId: aws.String(v),
})
if err != nil {
return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err)
Expand All @@ -109,12 +224,13 @@ func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*str

switch result.UpdateState {
case "Available":
updateCandidates = append(updateCandidates, *v)
updateCandidates = append(updateCandidates, v)
}
}

if updateCandidates == nil {
if len(updateCandidates) == 0 {
log.Printf("No instances to update")
return nil, nil
}
return updateCandidates, nil
}
47 changes: 38 additions & 9 deletions updater/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ var (
flagRegion = flag.String("region", "", "The AWS Region in which cluster is running.")
)

// updater groups the target cluster with the AWS service clients that the
// methods on this type use for their ECS and SSM requests.
type updater struct {
// cluster is sent as the Cluster field of ECS requests; _main fills it from the cluster flag.
cluster string
// ecs is the client used for the container instance and task API calls.
ecs *ecs.ECS
// ssm is the client used to send commands to instances and read their invocation output.
ssm *ssm.SSM
}

func main() {
if err := _main(); err != nil {
fmt.Fprintf(os.Stderr, err.Error())
log.Println(err.Error())
os.Exit(1)
}
}
Expand All @@ -39,34 +45,57 @@ func _main() error {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(*flagRegion),
}))
ecsClient := ecs.New(sess)
ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody))

instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize)
u := &updater{
cluster: *flagCluster,
ecs: ecs.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)),
ssm: ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody)),
}

listedInstances, err := u.listContainerInstances()
if err != nil {
return err
}

bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances)
ec2IDtoECSARN, err := u.filterBottlerocketInstances(listedInstances)
if err != nil {
return err
}

if len(bottlerocketInstances) == 0 {
if len(ec2IDtoECSARN) == 0 {
log.Printf("No Bottlerocket instances detected")
return nil
}

commandID, err := sendCommand(ssmClient, bottlerocketInstances, "apiclient update check")
// Make a slice of Bottlerocket instance IDs to use with sendCommand and checkSSMCommandOutput.
instances := make([]string, 0)
for instance, _ := range ec2IDtoECSARN {
instances = append(instances, instance)
}

commandID, err := u.sendCommand(instances, "apiclient update check")
if err != nil {
return err
}

instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances)
candidates, err := u.checkSSMCommandOutput(commandID, instances)
if err != nil {
return err
}

fmt.Println("Instances ready for update: ", instancesToUpdate)
if len(candidates) == 0 {
log.Printf("No instances to update")
return nil
}
fmt.Println("Instances ready for update: ", candidates)

for ec2ID, containerInstance := range ec2IDtoECSARN {
err := u.drain(containerInstance)
if err != nil {
log.Printf("%#v", err)
continue
}
log.Printf("Instance %s drained", ec2ID)
}
return nil
}

0 comments on commit bcb7d3e

Please sign in to comment.