Skip to content

Commit

Permalink
Fix a race in the policy updater map (#1687)
Browse files Browse the repository at this point in the history
  • Loading branch information
shleikes authored Sep 11, 2024
1 parent c840513 commit 66aaf13
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
19 changes: 19 additions & 0 deletions protocol/rpcconsumer/policies_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,22 @@ func (sm *syncMapPolicyUpdaters) Load(key string) (ret *updaters.PolicyUpdater,
}
return ret, true
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
// The function returns the value that was loaded or stored.
func (sm *syncMapPolicyUpdaters) LoadOrStore(key string, value *updaters.PolicyUpdater) (ret *updaters.PolicyUpdater, loaded bool) {
actual, loaded := sm.localMap.LoadOrStore(key, value)
if loaded {
// loaded from map
ret, loaded = actual.(*updaters.PolicyUpdater)
if !loaded {
utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil)
}
return ret, loaded
}

// stored in map
return value, false
}
5 changes: 2 additions & 3 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,13 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
}
chainID := rpcEndpoint.ChainID
// create policyUpdaters per chain
if policyUpdater, ok := policyUpdaters.Load(rpcEndpoint.ChainID); ok {
newPolicyUpdater := updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint)
if policyUpdater, ok := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater); ok {
err := policyUpdater.AddPolicySetter(chainParser, *rpcEndpoint)
if err != nil {
errCh <- err
return utils.LavaFormatError("failed adding policy setter", err)
}
} else {
policyUpdaters.Store(rpcEndpoint.ChainID, updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint))
}

err = statetracker.RegisterForSpecUpdatesOrSetStaticSpec(ctx, chainParser, options.cmdFlags.StaticSpecPath, *rpcEndpoint, rpcc.consumerStateTracker)
Expand Down

0 comments on commit 66aaf13

Please sign in to comment.