From 66aaf13be64209f4e53cb949ddff68dcaa3e99d5 Mon Sep 17 00:00:00 2001 From: Elad Gildnur <6321801+shleikes@users.noreply.github.com> Date: Wed, 11 Sep 2024 16:16:27 +0300 Subject: [PATCH] Fix a race in the policy updater map (#1687) --- protocol/rpcconsumer/policies_map.go | 19 +++++++++++++++++++ protocol/rpcconsumer/rpcconsumer.go | 5 ++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/protocol/rpcconsumer/policies_map.go b/protocol/rpcconsumer/policies_map.go index 543b71330b..d70d2de3da 100644 --- a/protocol/rpcconsumer/policies_map.go +++ b/protocol/rpcconsumer/policies_map.go @@ -26,3 +26,22 @@ func (sm *syncMapPolicyUpdaters) Load(key string) (ret *updaters.PolicyUpdater, } return ret, true } + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +// The function returns the value that was loaded or stored. +func (sm *syncMapPolicyUpdaters) LoadOrStore(key string, value *updaters.PolicyUpdater) (ret *updaters.PolicyUpdater, loaded bool) { + actual, loaded := sm.localMap.LoadOrStore(key, value) + if loaded { + // loaded from map + ret, loaded = actual.(*updaters.PolicyUpdater) + if !loaded { + utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil) + } + return ret, loaded + } + + // stored in map + return value, false +} diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index cd5e9eb4e7..67f4a24461 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -205,14 +205,13 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt } chainID := rpcEndpoint.ChainID // create policyUpdaters per chain - if policyUpdater, ok := policyUpdaters.Load(rpcEndpoint.ChainID); ok { + newPolicyUpdater := updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint) + if policyUpdater, ok := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater); ok { err := policyUpdater.AddPolicySetter(chainParser, *rpcEndpoint) if err != nil { errCh <- err return utils.LavaFormatError("failed adding policy setter", err) } - } else { - policyUpdaters.Store(rpcEndpoint.ChainID, updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint)) } err = statetracker.RegisterForSpecUpdatesOrSetStaticSpec(ctx, chainParser, options.cmdFlags.StaticSpecPath, *rpcEndpoint, rpcc.consumerStateTracker)