From e2ed35cd2d9a3402d57ee173a7d05aabfaddd8fc Mon Sep 17 00:00:00 2001 From: Matthew Pendrey Date: Thu, 19 Dec 2024 14:27:08 +0000 Subject: [PATCH] fix potential panic due to race condition and add slow consumer handling logic --- .../capabilities/remote/trigger_subscriber.go | 38 ++++-- .../remote/trigger_subscriber_test.go | 113 ++++++++++++++++++ 2 files changed, 142 insertions(+), 9 deletions(-) diff --git a/core/capabilities/remote/trigger_subscriber.go b/core/capabilities/remote/trigger_subscriber.go index 7edcbf5eba7..c2990b8b622 100644 --- a/core/capabilities/remote/trigger_subscriber.go +++ b/core/capabilities/remote/trigger_subscriber.go @@ -60,7 +60,7 @@ var _ services.Service = &triggerSubscriber{} // TODO makes this configurable with a default const ( - defaultSendChannelBufferSize = 1000 + DefaultSendChannelBufferSize = 1000 maxBatchedWorkflowIDs = 1000 ) @@ -120,7 +120,7 @@ func (s *triggerSubscriber) RegisterTrigger(ctx context.Context, request commonc regState, ok := s.registeredWorkflows[request.Metadata.WorkflowID] if !ok { regState = &subRegState{ - callback: make(chan commoncap.TriggerResponse, defaultSendChannelBufferSize), + callback: make(chan commoncap.TriggerResponse, DefaultSendChannelBufferSize), rawRequest: rawRequest, } s.registeredWorkflows[request.Metadata.WorkflowID] = regState @@ -171,16 +171,20 @@ func (s *triggerSubscriber) UnregisterTrigger(ctx context.Context, request commo s.mu.Lock() defer s.mu.Unlock() - state := s.registeredWorkflows[request.Metadata.WorkflowID] - if state != nil && state.callback != nil { - close(state.callback) - } - delete(s.registeredWorkflows, request.Metadata.WorkflowID) + s.closeSubscription(request.Metadata.WorkflowID) // Registrations will quickly expire on all remote nodes. // Alternatively, we could send UnregisterTrigger messages right away. return nil } +func (s *triggerSubscriber) closeSubscription(workflowID string) { + state := s.registeredWorkflows[workflowID] + if state != nil && state.callback != nil { + close(state.callback) + } + delete(s.registeredWorkflows, workflowID) +} + func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) { sender, err := ToPeerID(msg.Sender) if err != nil { @@ -204,7 +208,7 @@ func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) { } for _, workflowID := range meta.WorkflowIds { s.mu.RLock() - registration, found := s.registeredWorkflows[workflowID] + _, found := s.registeredWorkflows[workflowID] s.mu.RUnlock() if !found { s.lggr.Errorw("received message for unregistered workflow", "capabilityId", s.capInfo.ID, "workflowID", SanitizeLogString(workflowID), "sender", sender) @@ -231,7 +235,7 @@ func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) { continue } s.lggr.Infow("remote trigger event aggregated", "triggerEventID", meta.TriggerEventId, "capabilityId", s.capInfo.ID, "workflowId", workflowID) - registration.callback <- aggregatedResponse + s.sendResponse(workflowID, aggregatedResponse) } } } else { @@ -239,6 +243,22 @@ func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) { } } +func (s *triggerSubscriber) sendResponse(workflowID string, response commoncap.TriggerResponse) { + s.mu.Lock() + defer s.mu.Unlock() + + registration, found := s.registeredWorkflows[workflowID] + if found { + select { + case registration.callback <- response: + default: + s.lggr.Warn("slow consumer detected, closing subscription", "capabilityId", s.capInfo.ID, "workflowId", workflowID) + s.closeSubscription(workflowID) + return + } + } +} + func (s *triggerSubscriber) eventCleanupLoop() { defer s.wg.Done() ticker := time.NewTicker(s.config.MessageExpiry) diff --git a/core/capabilities/remote/trigger_subscriber_test.go b/core/capabilities/remote/trigger_subscriber_test.go index d5b48bc1dc8..2f680a810bd 100644 --- a/core/capabilities/remote/trigger_subscriber_test.go +++ b/core/capabilities/remote/trigger_subscriber_test.go @@ -1,6 +1,7 @@ package remote_test import ( + "strconv" "testing" "time" @@ -22,12 +23,124 @@ const ( peerID1 = "12D3KooWF3dVeJ6YoT5HFnYhmwQWWMoEwVFzJQ5kKCMX3ZityxMC" peerID2 = "12D3KooWQsmok6aD8PZqt3RnJhQRrNzKHLficq7zYFRp7kZ1hHP8" workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0" + workflowID2 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce2" ) var ( triggerEvent1 = map[string]any{"event": "triggerEvent1"} ) +func TestTriggerSubscriber_SlowConsumer(t *testing.T) { + lggr := logger.TestLogger(t) + ctx := testutils.Context(t) + capInfo := commoncap.CapabilityInfo{ + ID: "cap_id@1", + CapabilityType: commoncap.CapabilityTypeTrigger, + Description: "Remote Trigger", + } + p1 := p2ptypes.PeerID{} + require.NoError(t, p1.UnmarshalText([]byte(peerID1))) + capDonInfo := commoncap.DON{ + ID: 1, + Members: []p2ptypes.PeerID{p1}, + F: 0, + } + workflowDonInfo := commoncap.DON{ + ID: 2, + Members: []p2ptypes.PeerID{p1}, + F: 0, + } + dispatcher := remoteMocks.NewDispatcher(t) + + awaitRegistrationMessageCh := make(chan struct{}) + dispatcher.On("Send", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + select { + case awaitRegistrationMessageCh <- struct{}{}: + default: + } + }) + + // register trigger + config := &commoncap.RemoteTriggerConfig{ + RegistrationRefresh: 100 * time.Millisecond, + RegistrationExpiry: 100 * time.Second, + MinResponsesToAggregate: 1, + MessageExpiry: 100 * time.Second, + } + subscriber := remote.NewTriggerSubscriber(config, capInfo, capDonInfo, workflowDonInfo, dispatcher, nil, lggr) + require.NoError(t, subscriber.Start(ctx)) + + req := commoncap.TriggerRegistrationRequest{ + Metadata: commoncap.RequestMetadata{ + WorkflowID: workflowID1, + }, + } + triggerEventCallbackCh, err := subscriber.RegisterTrigger(ctx, req) + require.NoError(t, err) + <-awaitRegistrationMessageCh + + req2 := commoncap.TriggerRegistrationRequest{ + Metadata: commoncap.RequestMetadata{ + WorkflowID: workflowID2, + }, + } + triggerEventCallbackCh2, err := subscriber.RegisterTrigger(ctx, req2) + require.NoError(t, err) + <-awaitRegistrationMessageCh + + // receive trigger event + triggerEventValue, err := values.NewMap(triggerEvent1) + require.NoError(t, err) + capResponse := commoncap.TriggerResponse{ + Event: commoncap.TriggerEvent{ + Outputs: triggerEventValue, + }, + Err: nil, + } + marshaled, err := pb.MarshalTriggerResponse(capResponse) + require.NoError(t, err) + + // Simulate a slow consumer by not pulling events from triggerEventCallbackCh after a few events + // triggerEventCallbackCh2 should still receive events, note the number of messages + + for i := 0; i < remote.DefaultSendChannelBufferSize*2; i++ { + triggerEvent := &remotetypes.MessageBody{ + + Sender: p1[:], + Method: remotetypes.MethodTriggerEvent, + Metadata: &remotetypes.MessageBody_TriggerEventMetadata{ + TriggerEventMetadata: &remotetypes.TriggerEventMetadata{ + TriggerEventId: strconv.Itoa(i), + WorkflowIds: []string{workflowID1, workflowID2}, + }, + }, + Payload: marshaled, + } + + subscriber.Receive(ctx, triggerEvent) + if i < 5 { + select { + case <-triggerEventCallbackCh: + case <-triggerEventCallbackCh2: + } + } else { + <-triggerEventCallbackCh2 + } + } + + // Confirm the slow consumer channel was closed + for { + _, ok := <-triggerEventCallbackCh + if !ok { + break + } + } + + require.NoError(t, subscriber.UnregisterTrigger(ctx, req)) + require.NoError(t, subscriber.UnregisterTrigger(ctx, req2)) + require.NoError(t, subscriber.Close()) +} + func TestTriggerSubscriber_RegisterAndReceive(t *testing.T) { lggr := logger.TestLogger(t) ctx := testutils.Context(t)