From bd7db8ef2a5873493303100104025aeda116a4a2 Mon Sep 17 00:00:00 2001 From: andresvia Date: Sun, 18 Oct 2020 16:22:21 -0700 Subject: [PATCH] More unmangling. More working on #16 and #18 --- README.md | 2 + asg.go | 22 ++++---- asg_test.go | 2 +- ecs.go | 76 ++++++++++++++++--------- ecs_test.go | 160 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 223 insertions(+), 39 deletions(-) create mode 100644 ecs_test.go diff --git a/README.md b/README.md index f906aff..208b5b8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # go-awsecs +[![godoc reference](http://img.shields.io/badge/godoc-reference-blue.svg)](https://pkg.go.dev/github.com/Autodesk/go-awsecs) + [![travis ci](https://api.travis-ci.org/Autodesk/go-awsecs.svg?branch=master)](https://travis-ci.org/Autodesk/go-awsecs) [![coverage status](https://coveralls.io/repos/github/Autodesk/go-awsecs/badge.svg?branch=master)](https://coveralls.io/github/Autodesk/go-awsecs?branch=master) diff --git a/asg.go b/asg.go index 6550fdd..ba68b46 100644 --- a/asg.go +++ b/asg.go @@ -175,18 +175,16 @@ func drainingContainerInstanceIsDrained(ECSAPI ecs.ECS, clusterName, containerIn } func findDrainingContainerInstance(output *ecs.DescribeContainerInstancesOutput, containerInstanceID string) error { - for _, containerInstance := range output.ContainerInstances { - containerInstanceArn := *containerInstance.ContainerInstanceArn - parsedArn, err := arn.Parse(containerInstanceArn) - if err != nil { - return err - } - err = checkDrainingContainerInstance(containerInstance, parsedArn, containerInstanceID) - if err != nil { - return err - } + if len(output.ContainerInstances) == 0 { + return ErrContainerInstanceNotFound } - return backoff.Permanent(errors.New("container instance not found")) + containerInstance := output.ContainerInstances[0] + containerInstanceArn := *containerInstance.ContainerInstanceArn + parsedArn, err := arn.Parse(containerInstanceArn) + if err != nil { + return err + } + return checkDrainingContainerInstance(containerInstance, parsedArn, containerInstanceID) } func checkDrainingContainerInstance(containerInstance *ecs.ContainerInstance, parsedArn arn.ARN, containerInstanceID string) error { @@ -206,7 +204,7 @@ func checkDrainingContainerInstance(containerInstance *ecs.ContainerInstance, pa } return nil } - return nil + return ErrContainerInstanceNotFound } func drainAll(ASAPI autoscaling.AutoScaling, ECSAPI ecs.ECS, EC2API ec2.EC2, instances []ecsEC2Instance, asgName, clusterName string) error { diff --git a/asg_test.go b/asg_test.go index b372f4d..34c91b4 100644 --- a/asg_test.go +++ b/asg_test.go @@ -112,7 +112,7 @@ func TestCheckDrainingContainerInstance(t *testing.T) { }, { name: "Not matching container instance ID", - wantErr: false, + wantErr: true, args: args{ containerInstance: &ecs.ContainerInstance{}, parsedArn: arn.ARN{ diff --git a/ecs.go b/ecs.go index a138608..62429b4 100644 --- a/ecs.go +++ b/ecs.go @@ -24,6 +24,8 @@ var ( ErrServiceNotFound = errors.New("the service does not exist") // ErrServiceDeletedAfterUpdate service was updated and then deleted elsewhere ErrServiceDeletedAfterUpdate = backoff.Permanent(errors.New("the service was deleted after the update")) + // ErrContainerInstanceNotFound the container instance was removed from the cluster elsewhere + ErrContainerInstanceNotFound = backoff.Permanent(errors.New("container instance not found")) ) var ( @@ -165,42 +167,64 @@ func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMap if err != nil { return "", err } - arn := tdNew.TaskDefinition.TaskDefinitionArn - return *arn, nil + taskDefinitionArn := tdNew.TaskDefinition.TaskDefinitionArn + return *taskDefinitionArn, nil } -// TODO: add coverage func alterService(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string) (ecs.Service, ecs.Service, error) { output, err := api.DescribeServices(&ecs.DescribeServicesInput{Cluster: aws.String(cluster), Services: []*string{aws.String(service)}}) if err != nil { return ecs.Service{}, ecs.Service{}, err } - for _, svc := range output.Services { - clusterArn := *svc.ClusterArn - parsedClusterArn, err := arn.Parse(clusterArn) + copyTaskDefinitionAction := func(sourceTaskDefinition string) (string, error) { + return copyTaskDef(api, sourceTaskDefinition, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole) + } + updateAction := func(newTaskDefinition *string, desiredCount *int64) (*ecs.UpdateServiceOutput, error) { + updateServiceInput := &ecs.UpdateServiceInput{ + Cluster: aws.String(cluster), + Service: aws.String(service), + TaskDefinition: newTaskDefinition, + DesiredCount: desiredCount, + ForceNewDeployment: aws.Bool(true), + } + return api.UpdateService(updateServiceInput) + } + return findAndUpdateService(output, cluster, service, taskdef, desiredCount, copyTaskDefinitionAction, updateAction) +} + +func findAndUpdateService(output *ecs.DescribeServicesOutput, cluster, service, taskDefinition string, desiredCount *int64, copyTdAction func(string) (string, error), updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error)) (ecs.Service, ecs.Service, error) { + if len(output.Services) == 0 { + return ecs.Service{}, ecs.Service{}, ErrServiceNotFound + } + svc := output.Services[0] + clusterArn := *svc.ClusterArn + parsedClusterArn, err := arn.Parse(clusterArn) + if err != nil { + return ecs.Service{}, ecs.Service{}, err + } + return updateService(parsedClusterArn, svc, cluster, service, taskDefinition, desiredCount, copyTdAction, updateSvcAction) +} + +func updateService(parsedClusterArn arn.ARN, svc *ecs.Service, cluster, service, td string, desiredCount *int64, copyTdAction func(string) (string, error), updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error)) (ecs.Service, ecs.Service, error) { + clusterNameFound := strings.TrimPrefix(parsedClusterArn.Resource, "cluster/") + serviceNameFound := *svc.ServiceName + if clusterNameFound == cluster && serviceNameFound == service { + srcTaskDef := svc.TaskDefinition + if td != "" { + srcTaskDef = &td + } + newTd, err := copyTdAction(*srcTaskDef) if err != nil { - return ecs.Service{}, ecs.Service{}, err + return *svc, ecs.Service{}, err } - clusterNameFound := strings.TrimPrefix(parsedClusterArn.Resource, "cluster/") - serviceNameFound := *svc.ServiceName - if clusterNameFound == cluster && serviceNameFound == service { - srcTaskDef := svc.TaskDefinition - if taskdef != "" { - srcTaskDef = &taskdef - } - newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole) - if err != nil { - return *svc, ecs.Service{}, err - } - if desiredCount == nil { - desiredCount = svc.DesiredCount - } - updated, err := api.UpdateService(&ecs.UpdateServiceInput{Cluster: aws.String(cluster), Service: aws.String(service), TaskDefinition: aws.String(newTd), DesiredCount: desiredCount, ForceNewDeployment: aws.Bool(true)}) - if err != nil { - return *svc, ecs.Service{}, err - } - return *svc, *updated.Service, nil + if desiredCount == nil { + desiredCount = svc.DesiredCount + } + updated, err := updateSvcAction(aws.String(newTd), desiredCount) + if err != nil { + return *svc, ecs.Service{}, err } + return *svc, *updated.Service, nil } return ecs.Service{}, ecs.Service{}, ErrServiceNotFound } diff --git a/ecs_test.go b/ecs_test.go new file mode 100644 index 0000000..ad16bf6 --- /dev/null +++ b/ecs_test.go @@ -0,0 +1,160 @@ +package awsecs + +import ( + "errors" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/service/ecs" + "reflect" + "testing" +) + +func TestUpdateService(t *testing.T) { + type args struct { + parsedClusterArn arn.ARN + svc *ecs.Service + cluster string + service string + td string + desiredCount *int64 + copyTdAction func(string) (string, error) + updateSvcAction func(*string, *int64) (*ecs.UpdateServiceOutput, error) + } + tests := []struct { + name string + wantErr bool + beforeUpdate ecs.Service + afterUpdate ecs.Service + args args + }{ + { + name: "On copy error I want error", + wantErr: true, + beforeUpdate: ecs.Service{ + ServiceName: aws.String("my-service"), + }, + afterUpdate: ecs.Service{}, + args: args{ + parsedClusterArn: arn.ARN{ + Resource: "cluster/my-cluster", + }, + svc: &ecs.Service{ + ServiceName: aws.String("my-service"), + }, + cluster: "my-cluster", + service: "my-service", + td: "task:1", + desiredCount: aws.Int64(1), + copyTdAction: func(string) (string, error) { + return "", errors.New("failed to copy") + }, + updateSvcAction: nil, + }, + }, + { + name: "On update error I want error", + wantErr: true, + beforeUpdate: ecs.Service{ + ServiceName: aws.String("my-service"), + }, + afterUpdate: ecs.Service{}, + args: args{ + parsedClusterArn: arn.ARN{ + Resource: "cluster/my-cluster", + }, + svc: &ecs.Service{ + ServiceName: aws.String("my-service"), + }, + cluster: "my-cluster", + service: "my-service", + td: "task:1", + desiredCount: aws.Int64(1), + copyTdAction: func(string) (string, error) { + return "task:2", nil + }, + updateSvcAction: func(*string, *int64) (*ecs.UpdateServiceOutput, error) { + return nil, errors.New("failed to update") + }, + }, + }, + { + name: "On non matching cluster I want error", + wantErr: true, + beforeUpdate: ecs.Service{}, + afterUpdate: ecs.Service{}, + args: args{ + parsedClusterArn: arn.ARN{ + Resource: "cluster/my-cluster", + }, + svc: &ecs.Service{ + ServiceName: aws.String("my-service"), + }, + cluster: "my-other-cluster", + service: "my-service", + }, + }, + { + name: "On non matching service I want error", + wantErr: true, + beforeUpdate: ecs.Service{}, + afterUpdate: ecs.Service{}, + args: args{ + parsedClusterArn: arn.ARN{ + Resource: "cluster/my-cluster", + }, + svc: &ecs.Service{ + ServiceName: aws.String("my-service"), + }, + cluster: "my-cluster", + service: "my-other-service", + }, + }, + { + name: "Check before and after update", + wantErr: false, + beforeUpdate: ecs.Service{ + ServiceName: aws.String("my-service"), + }, + afterUpdate: ecs.Service{ + TaskDefinition: aws.String("task:2"), + }, + args: args{ + parsedClusterArn: arn.ARN{ + Resource: "cluster/my-cluster", + }, + svc: &ecs.Service{ + ServiceName: aws.String("my-service"), + }, + cluster: "my-cluster", + service: "my-service", + td: "task:1", + desiredCount: nil, + copyTdAction: func(s string) (string, error) { + return "task:2", nil + }, + updateSvcAction: func(s *string, i *int64) (*ecs.UpdateServiceOutput, error) { + return &ecs.UpdateServiceOutput{ + Service: &ecs.Service{ + TaskDefinition: aws.String("task:2"), + }, + }, nil + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := updateService(tt.args.parsedClusterArn, tt.args.svc, tt.args.cluster, tt.args.service, tt.args.td, tt.args.desiredCount, tt.args.copyTdAction, tt.args.updateSvcAction) + if (err != nil) != tt.wantErr { + t.Errorf("updateService() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.beforeUpdate) { + t.Errorf("updateService() got = %v, want %v", got, tt.beforeUpdate) + } + if !reflect.DeepEqual(got1, tt.afterUpdate) { + t.Errorf("updateService() got1 = %v, want %v", got1, tt.afterUpdate) + } + }) + } +}