Skip to content

Commit

Permalink
fix potential panic due to race condition and add slow consumer handl…
Browse files Browse the repository at this point in the history
…ing logic
  • Loading branch information
ettec committed Dec 19, 2024
1 parent 99efd78 commit e2ed35c
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 9 deletions.
38 changes: 29 additions & 9 deletions core/capabilities/remote/trigger_subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ var _ services.Service = &triggerSubscriber{}

// TODO: make this configurable with a default
const (
defaultSendChannelBufferSize = 1000
DefaultSendChannelBufferSize = 1000
maxBatchedWorkflowIDs = 1000
)

Expand Down Expand Up @@ -120,7 +120,7 @@ func (s *triggerSubscriber) RegisterTrigger(ctx context.Context, request commonc
regState, ok := s.registeredWorkflows[request.Metadata.WorkflowID]
if !ok {
regState = &subRegState{
callback: make(chan commoncap.TriggerResponse, defaultSendChannelBufferSize),
callback: make(chan commoncap.TriggerResponse, DefaultSendChannelBufferSize),
rawRequest: rawRequest,
}
s.registeredWorkflows[request.Metadata.WorkflowID] = regState
Expand Down Expand Up @@ -171,16 +171,20 @@ func (s *triggerSubscriber) UnregisterTrigger(ctx context.Context, request commo
s.mu.Lock()
defer s.mu.Unlock()

state := s.registeredWorkflows[request.Metadata.WorkflowID]
if state != nil && state.callback != nil {
close(state.callback)
}
delete(s.registeredWorkflows, request.Metadata.WorkflowID)
s.closeSubscription(request.Metadata.WorkflowID)
// Registrations will quickly expire on all remote nodes.
// Alternatively, we could send UnregisterTrigger messages right away.
return nil
}

// closeSubscription closes the callback channel registered for workflowID, if
// there is one, and then removes the workflow's entry from registeredWorkflows.
// Calling it for an unknown workflowID is a no-op.
//
// It does not take s.mu itself; the callers visible in this file
// (UnregisterTrigger and sendResponse) already hold the write lock.
func (s *triggerSubscriber) closeSubscription(workflowID string) {
	if reg, known := s.registeredWorkflows[workflowID]; known && reg != nil && reg.callback != nil {
		close(reg.callback)
	}
	delete(s.registeredWorkflows, workflowID)
}

func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) {
sender, err := ToPeerID(msg.Sender)
if err != nil {
Expand All @@ -204,7 +208,7 @@ func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) {
}
for _, workflowID := range meta.WorkflowIds {
s.mu.RLock()
registration, found := s.registeredWorkflows[workflowID]
_, found := s.registeredWorkflows[workflowID]
s.mu.RUnlock()
if !found {
s.lggr.Errorw("received message for unregistered workflow", "capabilityId", s.capInfo.ID, "workflowID", SanitizeLogString(workflowID), "sender", sender)
Expand All @@ -231,14 +235,30 @@ func (s *triggerSubscriber) Receive(_ context.Context, msg *types.MessageBody) {
continue
}
s.lggr.Infow("remote trigger event aggregated", "triggerEventID", meta.TriggerEventId, "capabilityId", s.capInfo.ID, "workflowId", workflowID)
registration.callback <- aggregatedResponse
s.sendResponse(workflowID, aggregatedResponse)
}
}
} else {
s.lggr.Errorw("received trigger event with unknown method", "method", SanitizeLogString(msg.Method), "sender", sender)
}
}

// sendResponse delivers response to the callback channel registered for
// workflowID without ever blocking the caller.
//
// The registration lookup and the channel send both happen while holding
// s.mu, the same lock UnregisterTrigger holds when it closes the channel, so
// the channel cannot be closed between the lookup and the send.
//
// If workflowID is no longer registered the response is dropped. If the
// channel cannot accept the response immediately (its buffer is full), the
// consumer is treated as too slow: the subscription is closed and removed
// instead of stalling the caller.
func (s *triggerSubscriber) sendResponse(workflowID string, response commoncap.TriggerResponse) {
	s.mu.Lock()
	defer s.mu.Unlock()

	registration, found := s.registeredWorkflows[workflowID]
	if !found {
		return
	}

	select {
	case registration.callback <- response:
	default:
		// Warnw (not Warn) so the key/value pairs are logged as structured
		// fields, consistent with the Errorw/Infow calls elsewhere in this file.
		s.lggr.Warnw("slow consumer detected, closing subscription", "capabilityId", s.capInfo.ID, "workflowId", workflowID)
		s.closeSubscription(workflowID)
	}
}

func (s *triggerSubscriber) eventCleanupLoop() {
defer s.wg.Done()
ticker := time.NewTicker(s.config.MessageExpiry)
Expand Down
113 changes: 113 additions & 0 deletions core/capabilities/remote/trigger_subscriber_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package remote_test

import (
"strconv"
"testing"
"time"

Expand All @@ -22,12 +23,124 @@ const (
peerID1 = "12D3KooWF3dVeJ6YoT5HFnYhmwQWWMoEwVFzJQ5kKCMX3ZityxMC"
peerID2 = "12D3KooWQsmok6aD8PZqt3RnJhQRrNzKHLficq7zYFRp7kZ1hHP8"
workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0"
workflowID2 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce2"
)

var (
triggerEvent1 = map[string]any{"event": "triggerEvent1"}
)

// TestTriggerSubscriber_SlowConsumer registers two workflows on one trigger
// subscriber, stops reading from the first workflow's callback channel after a
// few events, and checks that the first channel ends up closed while the
// second workflow keeps receiving every event.
func TestTriggerSubscriber_SlowConsumer(t *testing.T) {
lggr := logger.TestLogger(t)
ctx := testutils.Context(t)
capInfo := commoncap.CapabilityInfo{
ID: "cap_id@1",
CapabilityType: commoncap.CapabilityTypeTrigger,
Description: "Remote Trigger",
}
p1 := p2ptypes.PeerID{}
require.NoError(t, p1.UnmarshalText([]byte(peerID1)))
// Both DONs consist of the single peer p1 with F = 0.
capDonInfo := commoncap.DON{
ID: 1,
Members: []p2ptypes.PeerID{p1},
F: 0,
}
workflowDonInfo := commoncap.DON{
ID: 2,
Members: []p2ptypes.PeerID{p1},
F: 0,
}
dispatcher := remoteMocks.NewDispatcher(t)

// Every mocked Send attempts a non-blocking signal on this unbuffered
// channel; the signal is dropped when the test is not currently waiting on it.
awaitRegistrationMessageCh := make(chan struct{})
dispatcher.On("Send", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
select {
case awaitRegistrationMessageCh <- struct{}{}:
default:
}
})

// Build the subscriber; MinResponsesToAggregate is 1 and the expiry values are
// far longer than the test is expected to run.
config := &commoncap.RemoteTriggerConfig{
RegistrationRefresh: 100 * time.Millisecond,
RegistrationExpiry: 100 * time.Second,
MinResponsesToAggregate: 1,
MessageExpiry: 100 * time.Second,
}
subscriber := remote.NewTriggerSubscriber(config, capInfo, capDonInfo, workflowDonInfo, dispatcher, nil, lggr)
require.NoError(t, subscriber.Start(ctx))

// Register the first workflow (the one that will become the slow consumer)
// and wait until the mocked dispatcher reports a Send.
req := commoncap.TriggerRegistrationRequest{
Metadata: commoncap.RequestMetadata{
WorkflowID: workflowID1,
},
}
triggerEventCallbackCh, err := subscriber.RegisterTrigger(ctx, req)
require.NoError(t, err)
<-awaitRegistrationMessageCh

// Register the second workflow (the healthy consumer) the same way.
req2 := commoncap.TriggerRegistrationRequest{
Metadata: commoncap.RequestMetadata{
WorkflowID: workflowID2,
},
}
triggerEventCallbackCh2, err := subscriber.RegisterTrigger(ctx, req2)
require.NoError(t, err)
<-awaitRegistrationMessageCh

// Build the marshaled trigger response used as the payload of every event below.
triggerEventValue, err := values.NewMap(triggerEvent1)
require.NoError(t, err)
capResponse := commoncap.TriggerResponse{
Event: commoncap.TriggerEvent{
Outputs: triggerEventValue,
},
Err: nil,
}
marshaled, err := pb.MarshalTriggerResponse(capResponse)
require.NoError(t, err)

// Simulate a slow consumer by no longer pulling events from
// triggerEventCallbackCh after the first few events.
// triggerEventCallbackCh2 must still receive events. The loop sends twice
// DefaultSendChannelBufferSize events, i.e. more than the callback channel's
// buffer can hold without a reader.

for i := 0; i < remote.DefaultSendChannelBufferSize*2; i++ {
// Each event gets a distinct TriggerEventId and is addressed to both workflows.
triggerEvent := &remotetypes.MessageBody{

Sender: p1[:],
Method: remotetypes.MethodTriggerEvent,
Metadata: &remotetypes.MessageBody_TriggerEventMetadata{
TriggerEventMetadata: &remotetypes.TriggerEventMetadata{
TriggerEventId: strconv.Itoa(i),
WorkflowIds: []string{workflowID1, workflowID2},
},
},
Payload: marshaled,
}

subscriber.Receive(ctx, triggerEvent)
if i < 5 {
// Early iterations: take one event from whichever channel is ready.
select {
case <-triggerEventCallbackCh:
case <-triggerEventCallbackCh2:
}
} else {
// Later iterations: only the second consumer is read, so events for the
// first workflow pile up in its channel.
<-triggerEventCallbackCh2
}
}

// Confirm the slow consumer channel was closed: drain whatever is still
// buffered until the receive reports the channel as closed. This loop blocks
// for as long as the channel stays open and empty.
for {
_, ok := <-triggerEventCallbackCh
if !ok {
break
}
}

// Unregister both workflows (the first one's channel has already been observed
// closed above) and shut the subscriber down.
require.NoError(t, subscriber.UnregisterTrigger(ctx, req))
require.NoError(t, subscriber.UnregisterTrigger(ctx, req2))
require.NoError(t, subscriber.Close())
}

func TestTriggerSubscriber_RegisterAndReceive(t *testing.T) {
lggr := logger.TestLogger(t)
ctx := testutils.Context(t)
Expand Down

0 comments on commit e2ed35c

Please sign in to comment.