Skip to content

Commit

Permalink
Merge pull request #34 from bottlerocket-os/SSMCommands
Browse files Browse the repository at this point in the history
Add SSM SendCommand and GetCommandInvocation
  • Loading branch information
WilboMo authored Apr 14, 2021
2 parents cc73002 + ef2553a commit 58bcd2b
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 47 deletions.
120 changes: 120 additions & 0 deletions updater/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package main

import (
"encoding/json"
"fmt"
"log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ssm"
)

const (
pageSize = 50
)

func listContainerInstances(ecsClient *ecs.ECS, cluster string, pageSize int64) ([]*string, error) {
resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{
Cluster: &cluster,
MaxResults: aws.Int64(pageSize),
})
if err != nil {
return nil, fmt.Errorf("cannot list container instances: %#v", err)
}
log.Printf("%#v", resp)
return resp.ContainerInstanceArns, nil
}

func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]*string, error) {
resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &cluster, ContainerInstances: instances,
})
if err != nil {
return nil, fmt.Errorf("cannot describe container instances: %#v", err)
}

log.Printf("Container descriptions: %#v", resp)

ec2IDs := make([]*string, 0)
// Check the DescribeInstances response for Bottlerocket container instances, add them to ec2ids if detected
for _, instance := range resp.ContainerInstances {
if containsAttribute(instance.Attributes, "bottlerocket.variant") {
ec2IDs = append(ec2IDs, instance.Ec2InstanceId)
log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId)
}
}
return ec2IDs, nil
}

// checks if ECS Attributes struct contains a specified string
func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
for _, attr := range attrs {
if aws.StringValue(attr.Name) == searchString {
return true
}
}
return false
}

func sendCommand(ssmClient *ssm.SSM, instanceIDs []*string, ssmCommand string) (string, error) {
log.Printf("Checking InstanceIDs: %#v", instanceIDs)

resp, err := ssmClient.SendCommand(&ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
DocumentVersion: aws.String("$DEFAULT"),
InstanceIds: instanceIDs,
Parameters: map[string][]*string{
"commands": {aws.String(ssmCommand)},
},
})
if err != nil {
return "", fmt.Errorf("command invocation failed: %#v", err)
}

commandID := *resp.Command.CommandId
// Wait for the sent commands to complete
// TODO Update this to use WaitGroups
for _, v := range instanceIDs {
ssmClient.WaitUntilCommandExecuted(&ssm.GetCommandInvocationInput{
CommandId: &commandID,
InstanceId: v,
})
}
log.Printf("CommandID: %#v", commandID)
return commandID, nil
}

func checkCommandOutput(ssmClient *ssm.SSM, commandID string, instanceIDs []*string) ([]string, error) {
updateCandidates := make([]string, 0)
for _, v := range instanceIDs {
resp, err := ssmClient.GetCommandInvocation(&ssm.GetCommandInvocationInput{
CommandId: aws.String(commandID),
InstanceId: v,
})
if err != nil {
return nil, fmt.Errorf("failed to retreive command invocation output: %#v", err)
}

type updateCheckResult struct {
UpdateState string `json:"update_state"`
}

var result updateCheckResult
err = json.Unmarshal([]byte(*resp.StandardOutputContent), &result)
if err != nil {
log.Printf("failed to unmarshal command invocation output: %#v", err)
}
log.Println("update_state: ", result)

switch result.UpdateState {
case "Available":
updateCandidates = append(updateCandidates, *v)
}
}

if updateCandidates == nil {
log.Printf("No instances to update")
}
return updateCandidates, nil
}
69 changes: 22 additions & 47 deletions updater/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"flag"
"fmt"
"log"
Expand All @@ -9,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ssm"
)

var (
Expand All @@ -17,81 +19,54 @@ var (
)

func main() {
if err := _main(); err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
}

func _main() error {
flag.Parse()
switch {
case *flagCluster == "":
log.Println("cluster is required")
flag.Usage()
os.Exit(1)
return errors.New("cluster is required")
case *flagRegion == "":
log.Println("region is required")
flag.Usage()
os.Exit(1)
return errors.New("region is required")
}

sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(*flagRegion),
}))
ecsClient := ecs.New(sess)
ssmClient := ssm.New(sess, aws.NewConfig().WithLogLevel(aws.LogDebugWithHTTPBody))

instances, err := listContainerInstances(ecsClient, *flagCluster)
instances, err := listContainerInstances(ecsClient, *flagCluster, pageSize)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
return err
}

bottlerocketInstances, err := filterBottlerocketInstances(ecsClient, *flagCluster, instances)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
return err
}

if len(bottlerocketInstances) == 0 {
log.Printf("No Bottlerocket instances detected")
return nil
}
}

func listContainerInstances(ecsClient *ecs.ECS, cluster string) ([]*string, error) {
resp, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{Cluster: &cluster})
commandID, err := sendCommand(ssmClient, bottlerocketInstances, "apiclient update check")
if err != nil {
return nil, fmt.Errorf("Cannot list container instances: %#v", err)
return err
}
log.Printf("%#v", resp)
var values []*string

for _, v := range resp.ContainerInstanceArns {
values = append(values, v)
}
return values, nil
}

func filterBottlerocketInstances(ecsClient *ecs.ECS, cluster string, instances []*string) ([]string, error) {
resp, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
Cluster: &cluster, ContainerInstances: instances,
})
instancesToUpdate, err := checkCommandOutput(ssmClient, commandID, bottlerocketInstances)
if err != nil {
return nil, fmt.Errorf("Cannot describe container instances: %#v", err)
return err
}
log.Printf("Container descriptions: %#v", resp)

var ec2IDs []string

//Check the DescribeInstances response for Bottlerocket nodes, add them to ec2ids if detected
for _, instance := range resp.ContainerInstances {
if containsAttribute(instance.Attributes, "bottlerocket.variant") {
ec2IDs = append(ec2IDs, *instance.Ec2InstanceId)
log.Printf("Bottlerocket instance detected. Instance %#v added to check updates", *instance.Ec2InstanceId)
}
}
return ec2IDs, nil
}

// checks if ECS Attributes struct contains a specified string
func containsAttribute(attrs []*ecs.Attribute, searchString string) bool {
for _, attr := range attrs {
if aws.StringValue(attr.Name) == searchString {
return true
}
}
return false
fmt.Println("Instances ready for update: ", instancesToUpdate)
return nil
}

0 comments on commit 58bcd2b

Please sign in to comment.