diff --git a/context.go b/context.go new file mode 100644 index 0000000..6de26d9 --- /dev/null +++ b/context.go @@ -0,0 +1,54 @@ +package xrun + +import ( + "context" + "sync/atomic" +) + +// SignalStarted signals that the component has started. +func SignalStarted(ctx context.Context) { + if n, ok := ctx.Value(ctxKeyNotify).(*notifyCtx); ok { + n.start() + } +} + +type ctxKey int + +const ( + ctxKeyNotify ctxKey = iota +) + +type notifyCtx struct { + parent context.Context + _started atomic.Bool + ch chan struct{} +} + +func newNotifyCtx(parent context.Context) context.Context { + return context.WithValue(parent, ctxKeyNotify, ¬ifyCtx{ + parent: parent, + ch: make(chan struct{}, 1), + }) +} + +func (n *notifyCtx) started() <-chan struct{} { return n.ch } + +func (n *notifyCtx) start() { + if n._started.CompareAndSwap(false, true) { + close(n.ch) + } +} + +func started(ctx context.Context) <-chan struct{} { + if n, ok := ctx.Value(ctxKeyNotify).(*notifyCtx); ok { + return n.started() + } + + return closedCh() +} + +func closedCh() <-chan struct{} { + ch := make(chan struct{}, 1) + defer close(ch) + return ch +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..90a2829 --- /dev/null +++ b/context_test.go @@ -0,0 +1,10 @@ +package xrun + +import ( + "context" + "testing" +) + +func Test_started_with_no_notifyCtx(t *testing.T) { + <-started(context.TODO()) +} diff --git a/manager.go b/manager.go index 51936fa..ff8b0fe 100644 --- a/manager.go +++ b/manager.go @@ -12,7 +12,10 @@ import ( // NewManager creates a Manager and applies provided Option func NewManager(opts ...Option) *Manager { - m := &Manager{shutdownTimeout: NoTimeout} + m := &Manager{ + shutdownTimeout: NoTimeout, + maxStartWait: defaultMaxStartWait, + } for _, o := range opts { o.apply(m) @@ -24,13 +27,16 @@ func NewManager(opts ...Option) *Manager { // Manager helps to run multiple components // and waits for them to complete type Manager struct { - mu sync.Mutex + strategy Strategy + maxStartWait time.Duration + mu sync.Mutex internalCtx context.Context internalCancel context.CancelFunc - components []Component - wg sync.WaitGroup + components []Component + componentCancels []context.CancelFunc + wg sync.WaitGroup started bool stopping bool @@ -87,20 +93,38 @@ func (m *Manager) start() { defer m.mu.Unlock() m.started = true - for _, c := range m.components { - if c != nil { - m.startComponent(c) + switch m.strategy { + case OrderedStart: + for _, c := range m.components { + if c != nil { + notify := newNotifyCtx(m.internalCtx) + nCtx, cancel := context.WithCancel(notify) + m.startComponent(c, nCtx) + m.componentCancels = append([]context.CancelFunc{cancel}, m.componentCancels...) + + // Block until the component has started or the timeout has elapsed. + select { + case <-started(notify): + case <-time.After(m.maxStartWait): + } + } + } + case DefaultStartStop: + for _, c := range m.components { + if c != nil { + m.startComponent(c, m.internalCtx) + } } } } -func (m *Manager) startComponent(c Component) { +func (m *Manager) startComponent(c Component, ctx context.Context) { m.wg.Add(1) go func() { defer m.wg.Done() - if err := c.Run(m.internalCtx); err != nil && !errors.Is(err, context.Canceled) { + if err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { m.errChan <- err } }() @@ -110,7 +134,14 @@ func (m *Manager) engageStopProcedure() error { shutdownCancel := m.cancelFunc() defer shutdownCancel() - m.internalCancel() + switch m.strategy { + case OrderedStart: + for _, cancel := range m.componentCancels { + cancel() + } + case DefaultStartStop: + m.internalCancel() + } m.mu.Lock() defer m.mu.Unlock() diff --git a/manager_test.go b/manager_test.go index 04edd45..9bf6023 100644 --- a/manager_test.go +++ b/manager_test.go @@ -19,177 +19,263 @@ func TestManagerSuite(t *testing.T) { } func (s *ManagerSuite) TestNewManager() { - testcases := []struct { + type testcase struct { name string wantErr assert.ErrorAssertionFunc wantAddErr bool components []Component + wantOrder []int options []Option - }{ - { - name: "WithZeroComponents", - wantErr: assert.NoError, - }, - { - name: "WithOneComponent", - wantErr: assert.NoError, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(3 * time.Second) - <-ctx.Done() - return nil - }), + } + + s.Run("DefaultManager", func() { + testcases := []testcase{ + { + name: "WithZeroComponents", + wantErr: assert.NoError, }, - }, - { - name: "WithErrorOnComponentStart", - wantErr: assert.Error, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - return errors.New("start error") - }), + { + name: "WithOneComponent", + wantErr: assert.NoError, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(3 * time.Second) + <-ctx.Done() + return nil + }), + }, }, - }, - { - name: "WithGracefulShutdownErrorOnOneComponent", - options: []Option{ShutdownTimeout(5 * time.Second)}, - wantErr: assert.Error, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(time.Second) - return nil - }), - ComponentFunc(func(ctx context.Context) error { - <-ctx.Done() - time.Sleep(time.Minute) - return nil - }), + { + name: "WithErrorOnComponentStart", + wantErr: assert.Error, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + return errors.New("start error") + }), + }, }, - }, - { - name: "WithGracefulShutdownForTwoLongRunningComponents", - options: []Option{ShutdownTimeout(time.Minute)}, - wantErr: assert.NoError, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(5 * time.Second) - <-ctx.Done() - time.Sleep(5 * time.Second) - return nil - }), - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(10 * time.Second) - return nil - }), + { + name: "WithGracefulShutdownErrorOnOneComponent", + options: []Option{ShutdownTimeout(5 * time.Second)}, + wantErr: assert.Error, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(time.Second) + return nil + }), + ComponentFunc(func(ctx context.Context) error { + <-ctx.Done() + time.Sleep(time.Minute) + return nil + }), + }, }, - }, - { - name: "UndefinedGracefulShutdown", - wantErr: assert.NoError, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(5 * time.Second) - <-ctx.Done() - time.Sleep(5 * time.Second) - return nil - }), + { + name: "WithGracefulShutdownForTwoLongRunningComponents", + options: []Option{ShutdownTimeout(time.Minute)}, + wantErr: assert.NoError, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(5 * time.Second) + <-ctx.Done() + time.Sleep(5 * time.Second) + return nil + }), + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(10 * time.Second) + return nil + }), + }, }, - }, - { - name: "ShutdownWhenComponentReturnsContextErrorAsItIs", - wantErr: assert.NoError, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(2 * time.Second) - return nil - }), - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(time.Second) - return ctx.Err() - }), + { + name: "UndefinedGracefulShutdown", + wantErr: assert.NoError, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(5 * time.Second) + <-ctx.Done() + time.Sleep(5 * time.Second) + return nil + }), + }, }, - }, - { - name: "ShutdownWhenOneComponentReturnsErrorOnExit", - wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { - return assert.EqualError(t, err, `1 error occurred: + { + name: "ShutdownWhenComponentReturnsContextErrorAsItIs", + wantErr: assert.NoError, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(2 * time.Second) + return nil + }), + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(time.Second) + return ctx.Err() + }), + }, + }, + { + name: "ShutdownWhenOneComponentReturnsErrorOnExit", + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.EqualError(t, err, `1 error occurred: * shutdown error `, i...) + }, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(2 * time.Second) + return nil + }), + ComponentFunc(func(ctx context.Context) error { + time.Sleep(time.Second) + <-ctx.Done() + time.Sleep(time.Second) + return errors.New("shutdown error") + }), + }, }, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(2 * time.Second) - return nil - }), - ComponentFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - <-ctx.Done() - time.Sleep(time.Second) - return errors.New("shutdown error") - }), - }, - }, - { - name: "ShutdownWhenMoreThanOneComponentReturnsErrorOnExit", - wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { - return assert.EqualError(t, err, `2 errors occurred: + { + name: "ShutdownWhenMoreThanOneComponentReturnsErrorOnExit", + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.EqualError(t, err, `2 errors occurred: * shutdown error 2 * shutdown error 1 `, i...) + }, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + <-ctx.Done() + time.Sleep(2 * time.Second) + return nil + }), + ComponentFunc(func(ctx context.Context) error { + <-ctx.Done() + time.Sleep(3 * time.Second) + return errors.New("shutdown error 1") + }), + ComponentFunc(func(ctx context.Context) error { + <-ctx.Done() + time.Sleep(2 * time.Second) + return errors.New("shutdown error 2") + }), + }, + }, + } + + for _, t := range testcases { + s.Run(t.name, func() { + m := NewManager(t.options...) + + for _, r := range t.components { + s.NoError(m.Add(r)) + } + + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error, 1) + go func() { + errCh <- m.Run(ctx) + }() + + time.Sleep(1 * time.Second) + cancel() + + t.wantErr(s.T(), <-errCh) + }) + } + }) + + s.Run("OrderedStart", func() { + testcases := []testcase{ + { + name: "OrderedStartWithSignalStartedCalled", + wantOrder: []int{1, 2, 3}, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + SignalStarted(ctx) + <-ctx.Done() + return nil + }), + ComponentFunc(func(ctx context.Context) error { + SignalStarted(ctx) + <-ctx.Done() + return nil + }), + ComponentFunc(func(ctx context.Context) error { + SignalStarted(ctx) + <-ctx.Done() + return nil + }), + }, + wantErr: assert.NoError, }, - components: []Component{ - ComponentFunc(func(ctx context.Context) error { - <-ctx.Done() - time.Sleep(2 * time.Second) - return nil - }), - ComponentFunc(func(ctx context.Context) error { - <-ctx.Done() - time.Sleep(3 * time.Second) - return errors.New("shutdown error 1") - }), - ComponentFunc(func(ctx context.Context) error { - <-ctx.Done() - time.Sleep(2 * time.Second) - return errors.New("shutdown error 2") - }), + { + name: "OrderedStartWithSignalStartedNotCalled", + wantOrder: []int{1, 2, 3}, + options: []Option{MaxStartWait(100 * time.Millisecond)}, + components: []Component{ + ComponentFunc(func(ctx context.Context) error { + SignalStarted(ctx) + <-ctx.Done() + return nil + }), + ComponentFunc(func(ctx context.Context) error { + <-ctx.Done() + return nil + }), + ComponentFunc(func(ctx context.Context) error { + SignalStarted(ctx) + <-ctx.Done() + return nil + }), + }, + wantErr: assert.NoError, }, - }, - } + } - for _, t := range testcases { - s.Run(t.name, func() { - m := NewManager(t.options...) + for _, t := range testcases { + s.Run(t.name, func() { + m := NewManager(append(t.options, OrderedStart)...) - for _, r := range t.components { - s.NoError(m.Add(r)) - } + var order []int + for i, r := range t.components { + ii := i + rr := r + s.NoError(m.Add(ComponentFunc(func(ctx context.Context) error { + order = append(order, ii+1) + defer func() { + order = append(order, ii+1) + }() + return rr.Run(ctx) + }))) + } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) - errCh := make(chan error, 1) - go func() { - errCh <- m.Run(ctx) - }() + errCh := make(chan error, 1) + go func() { + errCh <- m.Run(ctx) + }() - time.Sleep(1 * time.Second) - cancel() + time.Sleep(1 * time.Second) + cancel() - t.wantErr(s.T(), <-errCh) - }) - } + t.wantErr(s.T(), <-errCh) + s.Equal(t.wantOrder, order[:len(t.wantOrder)]) + }) + } + }) } func (s *ManagerSuite) TestAddNewComponentAfterStop() { diff --git a/strategy.go b/strategy.go new file mode 100644 index 0000000..5519ca7 --- /dev/null +++ b/strategy.go @@ -0,0 +1,21 @@ +package xrun + +import "time" + +// Strategy defines the order of starting and stopping components +type Strategy int + +const ( + // DefaultStartStop starts and stops components in any order + DefaultStartStop Strategy = iota + // OrderedStart starts components in order they were added and stops them in random order + OrderedStart +) + +// MaxStartWait allows to set max wait time for component to start when using OrderedStart strategy +type MaxStartWait time.Duration + +const defaultMaxStartWait = 5 * time.Minute + +func (s Strategy) apply(m *Manager) { m.strategy = s } +func (t MaxStartWait) apply(m *Manager) { m.maxStartWait = time.Duration(t) }