diff --git a/.travis.yml b/.travis.yml index eb8b7b9..9fc00f2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,5 +5,7 @@ go: before_install: - go get github.com/mattn/goveralls - go get golang.org/x/tools/cmd/cover + - go get honnef.co/go/tools/cmd/staticcheck script: + - $HOME/gopath/bin/staticcheck ./... - $HOME/gopath/bin/goveralls -service=travis-ci diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0c98ddf..f279a28 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,10 +12,14 @@ Make sure your contribution passes the following validations: `golint ./...` -3. And new code must pass Go Vetting practices: +3. New code must pass Go Vetting practices: `go vet ./...` +4. And new code must pass [staticcheck](https://godoc.org/honnef.co/go/tools/cmd/staticcheck) checks: + + `staticcheck ./...` + I would like to keep this library simple, the proposed change must be a common use case. ## maintainer diff --git a/asg.go b/asg.go index 5020e98..6550fdd 100644 --- a/asg.go +++ b/asg.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" "github.com/aws/aws-sdk-go/service/ec2" @@ -170,16 +171,42 @@ func drainingContainerInstanceIsDrained(ECSAPI ecs.ECS, clusterName, containerIn if err != nil { return err } + return findDrainingContainerInstance(output, containerInstanceID) +} + +func findDrainingContainerInstance(output *ecs.DescribeContainerInstancesOutput, containerInstanceID string) error { for _, containerInstance := range output.ContainerInstances { - if *containerInstance.Status != "DRAINING" { - return backoff.Permanent(errors.New("the instance should be DRAINING but is not")) + containerInstanceArn := *containerInstance.ContainerInstanceArn + parsedArn, err := arn.Parse(containerInstanceArn) + if err != nil { + return err + } + err = checkDrainingContainerInstance(containerInstance, parsedArn, containerInstanceID) + if err != nil { + return err + } + } + return backoff.Permanent(errors.New("container instance not found")) +} + +func checkDrainingContainerInstance(containerInstance *ecs.ContainerInstance, parsedArn arn.ARN, containerInstanceID string) error { + containerInstanceIDFound := strings.TrimPrefix(parsedArn.Resource, "container-instance/") + if containerInstanceIDFound == containerInstanceID { + if *containerInstance.Status != ecs.ContainerInstanceStatusDraining { + errorStringFormat := "the instance should be %s but is not" + errorString := fmt.Sprintf(errorStringFormat, ecs.ContainerInstanceStatusDraining) + permanentError := errors.New(errorString) + return backoff.Permanent(permanentError) } if *containerInstance.RunningTasksCount != 0 { - return errors.New("container instance still DRAINING") + errorStringFormat := "container instance still %s" + errorString := fmt.Sprintf(errorStringFormat, ecs.ContainerInstanceStatusDraining) + retryableError := errors.New(errorString) + return retryableError } return nil } - return backoff.Permanent(errors.New("container instance not found")) + return nil } func drainAll(ASAPI autoscaling.AutoScaling, ECSAPI ecs.ECS, EC2API ec2.EC2, instances []ecsEC2Instance, asgName, clusterName string) error { diff --git a/asg_test.go b/asg_test.go index c74b772..b372f4d 100644 --- a/asg_test.go +++ b/asg_test.go @@ -1,8 +1,11 @@ package awsecs import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" + "github.com/aws/aws-sdk-go/service/ecs" "testing" ) @@ -54,3 +57,76 @@ func TestFilterInstancesToReplace(t *testing.T) { t.Errorf("unexpected") } } + +func TestCheckDrainingContainerInstance(t *testing.T) { + type args struct { + containerInstance *ecs.ContainerInstance + parsedArn arn.ARN + containerInstanceID string + } + tests := []struct { + name string + wantErr bool + args args + }{ + { + name: "Found container instance ACTIVE", + wantErr: true, + args: args{ + containerInstance: &ecs.ContainerInstance{ + Status: aws.String(ecs.ContainerInstanceStatusActive), + }, + parsedArn: arn.ARN{ + Resource: "container-instance/container_instance_ID", + }, + containerInstanceID: "container_instance_ID", + }, + }, + { + name: "Found container instance DRAINING and running tasks", + wantErr: true, + args: args{ + containerInstance: &ecs.ContainerInstance{ + Status: aws.String(ecs.ContainerInstanceStatusDraining), + RunningTasksCount: aws.Int64(10), + }, + parsedArn: arn.ARN{ + Resource: "container-instance/container_instance_ID", + }, + containerInstanceID: "container_instance_ID", + }, + }, + { + name: "Found container instance DRAINING and no longer running tasks", + wantErr: false, + args: args{ + containerInstance: &ecs.ContainerInstance{ + Status: aws.String(ecs.ContainerInstanceStatusDraining), + RunningTasksCount: aws.Int64(0), + }, + parsedArn: arn.ARN{ + Resource: "container-instance/container_instance_ID", + }, + containerInstanceID: "container_instance_ID", + }, + }, + { + name: "Not matching container instance ID", + wantErr: false, + args: args{ + containerInstance: &ecs.ContainerInstance{}, + parsedArn: arn.ARN{ + Resource: "container-instance/another_container_instance_ID", + }, + containerInstanceID: "container_instance_ID", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := checkDrainingContainerInstance(tt.args.containerInstance, tt.args.parsedArn, tt.args.containerInstanceID); (err != nil) != tt.wantErr { + t.Errorf("checkDrainingContainerInstance() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/cmd/update-aws-ecs-service/flags_test.go b/cmd/update-aws-ecs-service/flags_test.go index 8e4ba09..19a837e 100644 --- a/cmd/update-aws-ecs-service/flags_test.go +++ b/cmd/update-aws-ecs-service/flags_test.go @@ -19,8 +19,7 @@ func TestMapMapMapFlag_Set(t *testing.T) { if err := actualStruct.Set("container2=fluentd=option1=value1"); err != nil { t.Fatal(err) } - var expectedStruct mapMapMapFlag - expectedStruct = map[string]map[string]map[string]string{ + var expectedStruct mapMapMapFlag = map[string]map[string]map[string]string{ "container1": { "awslogs": { "region": "us-west-2", diff --git a/ecs.go b/ecs.go index 54801d7..a138608 100644 --- a/ecs.go +++ b/ecs.go @@ -3,10 +3,12 @@ package awsecs import ( "errors" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/ecs" "github.com/cenkalti/backoff" "log" "reflect" + "strings" ) var ( @@ -167,28 +169,38 @@ func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMap return *arn, nil } +// TODO: add coverage func alterService(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string) (ecs.Service, ecs.Service, error) { output, err := api.DescribeServices(&ecs.DescribeServicesInput{Cluster: aws.String(cluster), Services: []*string{aws.String(service)}}) if err != nil { return ecs.Service{}, ecs.Service{}, err } for _, svc := range output.Services { - srcTaskDef := svc.TaskDefinition - if taskdef != "" { - srcTaskDef = &taskdef - } - newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole) + clusterArn := *svc.ClusterArn + parsedClusterArn, err := arn.Parse(clusterArn) if err != nil { - return *svc, ecs.Service{}, err - } - if desiredCount == nil { - desiredCount = svc.DesiredCount + return ecs.Service{}, ecs.Service{}, err } - updated, err := api.UpdateService(&ecs.UpdateServiceInput{Cluster: aws.String(cluster), Service: aws.String(service), TaskDefinition: aws.String(newTd), DesiredCount: desiredCount, ForceNewDeployment: aws.Bool(true)}) - if err != nil { - return *svc, ecs.Service{}, err + clusterNameFound := strings.TrimPrefix(parsedClusterArn.Resource, "cluster/") + serviceNameFound := *svc.ServiceName + if clusterNameFound == cluster && serviceNameFound == service { + srcTaskDef := svc.TaskDefinition + if taskdef != "" { + srcTaskDef = &taskdef + } + newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole) + if err != nil { + return *svc, ecs.Service{}, err + } + if desiredCount == nil { + desiredCount = svc.DesiredCount + } + updated, err := api.UpdateService(&ecs.UpdateServiceInput{Cluster: aws.String(cluster), Service: aws.String(service), TaskDefinition: aws.String(newTd), DesiredCount: desiredCount, ForceNewDeployment: aws.Bool(true)}) + if err != nil { + return *svc, ecs.Service{}, err + } + return *svc, *updated.Service, nil } - return *svc, *updated.Service, nil } return ecs.Service{}, ecs.Service{}, ErrServiceNotFound }