diff --git a/ecs-agent/api/container/restart/restart_tracker.go b/ecs-agent/api/container/restart/restart_tracker.go new file mode 100644 index 00000000000..15b891dccce --- /dev/null +++ b/ecs-agent/api/container/restart/restart_tracker.go @@ -0,0 +1,96 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package restart + +import ( + "fmt" + "sync" + "time" + + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" +) + +type RestartTracker struct { + RestartCount int `json:"restartCount,omitempty"` + LastStartedAt time.Time `json:"lastStartedAt,omitempty"` + restartPolicy RestartPolicy + lock sync.RWMutex +} + +// RestartPolicy represents a policy that contains key information considered when +// deciding whether or not a container should be restarted after it has exited. +type RestartPolicy struct { + Enabled bool `json:"enabled"` + IgnoredExitCodes []int `json:"ignoredExitCodes"` + RestartAttemptPeriod time.Duration `json:"restartAttemptPeriod"` +} + +func NewRestartTracker(restartPolicy RestartPolicy) *RestartTracker { + return &RestartTracker{ + restartPolicy: restartPolicy, + } +} + +func (rt *RestartTracker) GetRestartEnabled() bool { + rt.lock.RLock() + defer rt.lock.RUnlock() + return rt.restartPolicy.Enabled +} + +func (rt *RestartTracker) GetLastStartedAt() time.Time { + rt.lock.RLock() + defer rt.lock.RUnlock() + return rt.LastStartedAt +} + +func (rt *RestartTracker) GetRestartCount() int { + rt.lock.RLock() + defer rt.lock.RUnlock() + return rt.RestartCount +} + +// RecordRestart updates the restart tracker's metadata after a restart has occurred. +// This metadata is used to calculate when restarts should occur and track how many +// have occurred. It is not the job of this method to determine if a restart should +// occur or restart the container. It is expected to receive a startedAt time from the container runtime. +func (rt *RestartTracker) RecordRestart(startedAt time.Time) { + rt.lock.RLock() + defer rt.lock.RUnlock() + rt.RestartCount++ + rt.LastStartedAt = startedAt +} + +// ShouldRestart returns whether the container should restart and a reason string +// explaining why not. +func (rt *RestartTracker) ShouldRestart(exitCode *int, startedAt time.Time, + desiredStatus apicontainerstatus.ContainerStatus) (bool, string) { + if !rt.restartPolicy.Enabled { + return false, "restart policy is not enabled" + } + if desiredStatus == apicontainerstatus.ContainerStopped { + return false, "container's desired status is stopped" + } + if exitCode == nil { + return false, "exit code is nil" + } + for _, ignoredCode := range rt.restartPolicy.IgnoredExitCodes { + if ignoredCode == *exitCode { + return false, fmt.Sprintf("exit code %d should be ignored", *exitCode) + } + } + if time.Since(startedAt) < rt.restartPolicy.RestartAttemptPeriod { + return false, "attempt reset period has not elapsed" + } + return true, "" +} diff --git a/ecs-agent/api/container/restart/restart_tracker_test.go b/ecs-agent/api/container/restart/restart_tracker_test.go new file mode 100644 index 00000000000..8c0028556b7 --- /dev/null +++ b/ecs-agent/api/container/restart/restart_tracker_test.go @@ -0,0 +1,159 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package restart + +import ( + "testing" + "time" + + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + + "github.com/stretchr/testify/assert" +) + +func TestShouldRestart(t *testing.T) { + ignoredCode := 0 + rt := NewRestartTracker(RestartPolicy{ + Enabled: false, + IgnoredExitCodes: []int{ignoredCode}, + RestartAttemptPeriod: 60 * time.Second, + }) + testCases := []struct { + name string + rp RestartPolicy + exitCode int + startedAt time.Time + desiredStatus apicontainerstatus.ContainerStatus + expected bool + expectedReason string + }{ + { + name: "restart policy disabled", + rp: RestartPolicy{ + Enabled: false, + IgnoredExitCodes: []int{ignoredCode}, + RestartAttemptPeriod: 60 * time.Second, + }, + exitCode: 1, + startedAt: time.Now().Add(2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "restart policy is not enabled", + }, + { + name: "ignored exit code", + rp: RestartPolicy{ + Enabled: true, + IgnoredExitCodes: []int{ignoredCode}, + RestartAttemptPeriod: 60 * time.Second, + }, + exitCode: 0, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "exit code 0 should be ignored", + }, + { + name: "non ignored exit code", + rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second}, + exitCode: 1, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: true, + expectedReason: "", + }, + { + name: "nil exit code", + rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second}, + exitCode: -1, + startedAt: time.Now().Add(-2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "exit code is nil", + }, + { + name: "desired status stopped", + rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second}, + exitCode: 1, + startedAt: time.Now().Add(2 * time.Minute), + desiredStatus: apicontainerstatus.ContainerStopped, + expected: false, + expectedReason: "container's desired status is stopped", + }, + { + name: "attempt reset period not elapsed", + rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second}, + exitCode: 1, + startedAt: time.Now(), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "attempt reset period has not elapsed", + }, + { + name: "attempt reset period not elapsed within one second", + rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second}, + exitCode: 1, + startedAt: time.Now().Add(-time.Second * 59), + desiredStatus: apicontainerstatus.ContainerRunning, + expected: false, + expectedReason: "attempt reset period has not elapsed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rt.restartPolicy = tc.rp + + // Because we cannot instantiate int pointers directly, + // check for the exit code and leave this int pointer as nil + // if there is no value to override it. + var exitCodeAdjusted *int + if tc.exitCode != -1 { + exitCodeAdjusted = &tc.exitCode + } + + shouldRestart, reason := rt.ShouldRestart(exitCodeAdjusted, tc.startedAt, tc.desiredStatus) + assert.Equal(t, tc.expected, shouldRestart) + assert.Equal(t, tc.expectedReason, reason) + }) + } +} + +func TestRecordRestart(t *testing.T) { + rt := NewRestartTracker(RestartPolicy{ + Enabled: false, + RestartAttemptPeriod: 60 * time.Second, + }) + startedAt := time.Now() + assert.Equal(t, 0, rt.RestartCount) + for i := 1; i < 1000; i++ { + newTime := startedAt.Add(time.Duration(i) * time.Second) + rt.RecordRestart(newTime) + assert.Equal(t, i, rt.RestartCount) + assert.Equal(t, newTime, rt.LastStartedAt) + } +} + +func TestRecordRestartPolicy(t *testing.T) { + rt := NewRestartTracker(RestartPolicy{ + Enabled: false, + RestartAttemptPeriod: 60 * time.Second, + }) + assert.Equal(t, 0, rt.RestartCount) + assert.Equal(t, 0, len(rt.restartPolicy.IgnoredExitCodes)) + assert.NotNil(t, rt.restartPolicy) +}