From b3035e3ed2ccae4cc17da40e908b087b6480962b Mon Sep 17 00:00:00 2001 From: Aaron Lu <50029043+aalu1418@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:16:42 -0600 Subject: [PATCH] cleanup: deduplicate shared cache logic (#803) * cleanup: deduplicate shared cache logic * use Read + pass name to cache state management --- pkg/solana/cache_test.go | 40 +++++---- pkg/solana/client/cache.go | 143 ++++++++++++++++++++++++++++++ pkg/solana/config_tracker.go | 4 +- pkg/solana/median_contract.go | 6 +- pkg/solana/relay.go | 4 +- pkg/solana/state_cache.go | 116 ++---------------------- pkg/solana/transmissions_cache.go | 116 ++---------------------- pkg/solana/transmitter.go | 2 +- 8 files changed, 183 insertions(+), 248 deletions(-) create mode 100644 pkg/solana/client/cache.go diff --git a/pkg/solana/cache_test.go b/pkg/solana/cache_test.go index 3351ecf68..e39bb52ad 100644 --- a/pkg/solana/cache_test.go +++ b/pkg/solana/cache_test.go @@ -167,29 +167,33 @@ func TestCache(t *testing.T) { })) lggr := logger.Test(t) - stateCache := StateCache{ - StateID: solana.MustPublicKeyFromBase58("11111111111111111111111111111111"), - cfg: config.NewDefault(), - reader: testSetupReader(t, mockServer.URL), - lggr: lggr, - } + stateCache := NewStateCache( + solana.MustPublicKeyFromBase58("11111111111111111111111111111111"), + "test-chain-id", + config.NewDefault(), + testSetupReader(t, mockServer.URL), + lggr, + ) require.NoError(t, stateCache.Start(ctx)) require.NoError(t, stateCache.Close()) - require.NoError(t, stateCache.fetchState(ctx)) - assert.Equal(t, "GADeYvXjPwZP7ds1yDY9VFp12bNjdxT1YyksMvFGK9xn", stateCache.state.Transmissions.String()) - assert.True(t, !stateCache.stateTime.IsZero()) - - transmissionsCache := TransmissionsCache{ - TransmissionsID: solana.MustPublicKeyFromBase58("11111111111111111111111111111112"), - cfg: config.NewDefault(), - reader: testSetupReader(t, mockServer.URL), - lggr: lggr, - } + require.NoError(t, stateCache.Fetch(ctx)) + state, err := stateCache.Read() + require.NoError(t, err) + assert.Equal(t, "GADeYvXjPwZP7ds1yDY9VFp12bNjdxT1YyksMvFGK9xn", state.Transmissions.String()) + assert.True(t, !stateCache.Timestamp().IsZero()) + + transmissionsCache := NewTransmissionsCache( + solana.MustPublicKeyFromBase58("11111111111111111111111111111112"), + "test-chain-id", + config.NewDefault(), + testSetupReader(t, mockServer.URL), + lggr, + ) require.NoError(t, transmissionsCache.Start(ctx)) require.NoError(t, transmissionsCache.Close()) - require.NoError(t, transmissionsCache.fetchLatestTransmission(ctx)) - answer, err := transmissionsCache.ReadAnswer() + require.NoError(t, transmissionsCache.Fetch(ctx)) + answer, err := transmissionsCache.Read() assert.NoError(t, err) assert.Equal(t, expectedTime, answer.Timestamp) assert.Equal(t, expectedAns, answer.Data.String()) diff --git a/pkg/solana/client/cache.go b/pkg/solana/client/cache.go new file mode 100644 index 000000000..a9ff0865d --- /dev/null +++ b/pkg/solana/client/cache.go @@ -0,0 +1,143 @@ +package client + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/gagliardetto/solana-go" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" +) + +type CacheGetter[R any] func(ctx context.Context) (res R, slot uint64, err error) + +// Cache is a generic implementation for caching data from the chain +type Cache[R any] struct { + services.StateMachine + + // identifier + metricName string + Account solana.PublicKey + ChainID string + + // stored answer + resLock sync.RWMutex + res R + resTime time.Time + + // dependencies + getter CacheGetter[R] + cfg config.Config + lggr logger.Logger + + // polling + done chan struct{} + stopCh services.StopChan +} + +func NewCache[R any](metricName string, account solana.PublicKey, chainID string, cfg config.Config, getFunc CacheGetter[R], lggr logger.Logger) *Cache[R] { + return &Cache[R]{ + metricName: metricName, + Account: account, + ChainID: chainID, + getter: getFunc, + lggr: lggr, + cfg: cfg, + } +} + +func (c *Cache[R]) Name() string { + return c.lggr.Name() +} + +// Start polling +func (c *Cache[R]) Start(ctx context.Context) error { + return c.StartOnce("cache_"+c.metricName, func() error { + c.done = make(chan struct{}) + c.stopCh = make(chan struct{}) + // We synchronously update the config on start so that + // when OCR starts there is config available (if possible). + // Avoids confusing "contract has not been configured" OCR errors. + err := c.Fetch(ctx) + if err != nil { + c.lggr.Warnf("error in initial fetch %s", err) + } + go c.Poll() + return nil + }) +} + +// Close stops the polling +func (c *Cache[R]) Close() error { + return c.StopOnce("cache_"+c.metricName, func() error { + close(c.stopCh) + <-c.done + return nil + }) +} + +// Poll contains the polling implementation +func (c *Cache[R]) Poll() { + defer close(c.done) + ctx, cancel := c.stopCh.NewCtx() + defer cancel() + c.lggr.Debugf("Starting polling: %s", c.Account) + tick := time.After(0) + for { + select { + case <-ctx.Done(): + c.lggr.Debugf("Stopping polling: %s", c.Account) + return + case <-tick: + start := time.Now() + err := c.Fetch(ctx) + if err != nil { + c.lggr.Errorf("error in Poll.fetch %s", err) + } + // Note negative duration will be immediately ready + tick = time.After(utils.WithJitter(c.cfg.OCR2CachePollPeriod()) - time.Since(start)) + } + } +} + +// Read reads the latest result from memory with mutex and errors if timeout is exceeded +func (c *Cache[R]) Read() (R, error) { + c.resLock.RLock() + defer c.resLock.RUnlock() + + // check if stale timeout + var err error + if time.Since(c.resTime) > c.cfg.OCR2CacheTTL() { + err = errors.New("error in Read: stale data, polling is likely experiencing errors") + } + return c.res, err +} + +func (c *Cache[R]) Timestamp() time.Time { + return c.resTime +} + +func (c *Cache[R]) Fetch(ctx context.Context) error { + c.lggr.Debugf("fetch for account: %s", c.Account) + res, _, err := c.getter(ctx) + if err != nil { + return err + } + c.lggr.Debugf("latest fetched for account: %s, result: %v", c.Account, res) + + timestamp := time.Now() + monitor.SetCacheTimestamp(timestamp, c.metricName, c.ChainID, c.Account.String()) + // acquire lock and write to state + c.resLock.Lock() + defer c.resLock.Unlock() + c.res = res + c.resTime = timestamp + return nil +} diff --git a/pkg/solana/config_tracker.go b/pkg/solana/config_tracker.go index 333e296ba..3287aa115 100644 --- a/pkg/solana/config_tracker.go +++ b/pkg/solana/config_tracker.go @@ -21,7 +21,7 @@ func (c *ConfigTracker) Notify() <-chan struct{} { // LatestConfigDetails returns information about the latest configuration, // but not the configuration itself. func (c *ConfigTracker) LatestConfigDetails(ctx context.Context) (changedInBlock uint64, configDigest types.ConfigDigest, err error) { - state, err := c.stateCache.ReadState() + state, err := c.stateCache.Read() return state.Config.LatestConfigBlockNumber, state.Config.LatestConfigDigest, err } @@ -66,7 +66,7 @@ func ConfigFromState(state State) (types.ContractConfig, error) { // LatestConfig returns the latest configuration. func (c *ConfigTracker) LatestConfig(ctx context.Context, changedInBlock uint64) (types.ContractConfig, error) { - state, err := c.stateCache.ReadState() + state, err := c.stateCache.Read() if err != nil { return types.ContractConfig{}, err } diff --git a/pkg/solana/median_contract.go b/pkg/solana/median_contract.go index 863143d57..d1199e18c 100644 --- a/pkg/solana/median_contract.go +++ b/pkg/solana/median_contract.go @@ -23,11 +23,11 @@ func (c *MedianContract) LatestTransmissionDetails( latestTimestamp time.Time, err error, ) { - state, err := c.stateCache.ReadState() + state, err := c.stateCache.Read() if err != nil { return configDigest, epoch, round, latestAnswer, latestTimestamp, err } - answer, err := c.transmissionsCache.ReadAnswer() + answer, err := c.transmissionsCache.Read() if err != nil { return configDigest, epoch, round, latestAnswer, latestTimestamp, err } @@ -60,6 +60,6 @@ func (c *MedianContract) LatestRoundRequested( round uint8, err error, ) { - state, err := c.stateCache.ReadState() + state, err := c.stateCache.Read() return state.Config.LatestConfigDigest, 0, 0, err } diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index bbf1a6fa7..a272f15ea 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -222,7 +222,7 @@ func newConfigProvider(ctx context.Context, lggr logger.Logger, chain Chain, arg } func (c *configProvider) Name() string { - return c.stateCache.lggr.Name() + return c.stateCache.Name() } func (c *configProvider) Start(ctx context.Context) error { @@ -260,7 +260,7 @@ type medianProvider struct { } func (p *medianProvider) Name() string { - return p.stateCache.lggr.Name() + return p.stateCache.Name() } // start both cache services diff --git a/pkg/solana/state_cache.go b/pkg/solana/state_cache.go index 54293e7e9..9faa766d0 100644 --- a/pkg/solana/state_cache.go +++ b/pkg/solana/state_cache.go @@ -2,23 +2,17 @@ package solana import ( "context" - "encoding/hex" "errors" "fmt" - "sync" - "time" bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" ) var ( @@ -26,115 +20,15 @@ var ( ) type StateCache struct { - services.StateMachine - // on-chain program + 2x state accounts (state + transmissions) - StateID solana.PublicKey - chainID string - - stateLock sync.RWMutex - state State - stateTime time.Time - - // dependencies - reader client.Reader - cfg config.Config - lggr logger.Logger - - // polling - done chan struct{} - stopCh services.StopChan + *client.Cache[State] } func NewStateCache(stateID solana.PublicKey, chainID string, cfg config.Config, reader client.Reader, lggr logger.Logger) *StateCache { - return &StateCache{ - StateID: stateID, - chainID: chainID, - reader: reader, - lggr: lggr, - cfg: cfg, - } -} - -// Start polling -func (c *StateCache) Start(ctx context.Context) error { - return c.StartOnce("pollState", func() error { - c.done = make(chan struct{}) - c.stopCh = make(chan struct{}) - // We synchronously update the config on start so that - // when OCR starts there is config available (if possible). - // Avoids confusing "contract has not been configured" OCR errors. - err := c.fetchState(ctx) - if err != nil { - c.lggr.Warnf("error in initial PollState.fetchState %s", err) - } - go c.PollState() - return nil - }) -} - -// PollState contains the state and transmissions polling implementation -func (c *StateCache) PollState() { - defer close(c.done) - ctx, cancel := c.stopCh.NewCtx() - defer cancel() - c.lggr.Debugf("Starting state polling for state: %s", c.StateID) - tick := time.After(0) - for { - select { - case <-ctx.Done(): - c.lggr.Debugf("Stopping state polling for state: %s", c.StateID) - return - case <-tick: - // async poll both ocr2 states - start := time.Now() - err := c.fetchState(ctx) - if err != nil { - c.lggr.Errorf("error in PollState.fetchState %s", err) - } - // Note negative duration will be immediately ready - tick = time.After(utils.WithJitter(c.cfg.OCR2CachePollPeriod()) - time.Since(start)) - } - } -} - -// Close stops the polling -func (c *StateCache) Close() error { - return c.StopOnce("pollState", func() error { - close(c.stopCh) - <-c.done - return nil - }) -} - -// ReadState reads the latest state from memory with mutex and errors if timeout is exceeded -func (c *StateCache) ReadState() (State, error) { - c.stateLock.RLock() - defer c.stateLock.RUnlock() - - var err error - if time.Since(c.stateTime) > c.cfg.OCR2CacheTTL() { - err = errors.New("error in ReadState: stale state data, polling is likely experiencing errors") + name := "ocr2_median_state" + getter := func(ctx context.Context) (State, uint64, error) { + return GetState(ctx, reader, stateID, cfg.Commitment()) } - return c.state, err -} - -func (c *StateCache) fetchState(ctx context.Context) error { - c.lggr.Debugf("fetch state for account: %s", c.StateID.String()) - state, _, err := GetState(ctx, c.reader, c.StateID, c.cfg.Commitment()) - if err != nil { - return err - } - - c.lggr.Debugf("state fetched for account: %s, result (config digest): %v", c.StateID, hex.EncodeToString(state.Config.LatestConfigDigest[:])) - - timestamp := time.Now() - monitor.SetCacheTimestamp(timestamp, "ocr2_median_state", c.chainID, c.StateID.String()) - // acquire lock and write to state - c.stateLock.Lock() - defer c.stateLock.Unlock() - c.state = state - c.stateTime = timestamp - return nil + return &StateCache{client.NewCache(name, stateID, chainID, cfg, getter, logger.With(lggr, "cache", name))} } func GetState(ctx context.Context, reader client.AccountReader, account solana.PublicKey, commitment rpc.CommitmentType) (State, uint64, error) { diff --git a/pkg/solana/transmissions_cache.go b/pkg/solana/transmissions_cache.go index 8f1ceab5e..75ad30a6b 100644 --- a/pkg/solana/transmissions_cache.go +++ b/pkg/solana/transmissions_cache.go @@ -4,133 +4,27 @@ import ( "context" "errors" "fmt" - "sync" - "time" bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" ) type TransmissionsCache struct { - services.StateMachine - - // on-chain program + 2x state accounts (state + transmissions) - TransmissionsID solana.PublicKey - chainID string - - ansLock sync.RWMutex - answer Answer - ansTime time.Time - - // dependencies - reader client.Reader - cfg config.Config - lggr logger.Logger - - // polling - done chan struct{} - stopCh services.StopChan + *client.Cache[Answer] } func NewTransmissionsCache(transmissionsID solana.PublicKey, chainID string, cfg config.Config, reader client.Reader, lggr logger.Logger) *TransmissionsCache { - return &TransmissionsCache{ - TransmissionsID: transmissionsID, - chainID: chainID, - reader: reader, - lggr: lggr, - cfg: cfg, - } -} - -// Start polling -func (c *TransmissionsCache) Start(ctx context.Context) error { - return c.StartOnce("pollTransmissions", func() error { - c.done = make(chan struct{}) - c.stopCh = make(chan struct{}) - // We synchronously update the config on start so that - // when OCR starts there is config available (if possible). - // Avoids confusing "contract has not been configured" OCR errors. - err := c.fetchLatestTransmission(ctx) - if err != nil { - c.lggr.Warnf("error in initial PollTransmissions %s", err) - } - go c.PollTransmissions() - return nil - }) -} - -// Close stops the polling -func (c *TransmissionsCache) Close() error { - return c.StopOnce("transmissionCache", func() error { - close(c.stopCh) - <-c.done - return nil - }) -} - -// PollTransmissions contains the transmissions polling implementation -func (c *TransmissionsCache) PollTransmissions() { - defer close(c.done) - ctx, cancel := c.stopCh.NewCtx() - defer cancel() - c.lggr.Debugf("Starting state polling transmissions: %s", c.TransmissionsID) - tick := time.After(0) - for { - select { - case <-ctx.Done(): - c.lggr.Debugf("Stopping state polling transmissions: %s", c.TransmissionsID) - return - case <-tick: - // async poll both transmission + ocr2 states - start := time.Now() - err := c.fetchLatestTransmission(ctx) - if err != nil { - c.lggr.Errorf("error in PollTransmissions.fetchLatestTransmission %s", err) - } - // Note negative duration will be immediately ready - tick = time.After(utils.WithJitter(c.cfg.OCR2CachePollPeriod()) - time.Since(start)) - } - } -} - -// ReadAnswer reads the latest state from memory with mutex and errors if timeout is exceeded -func (c *TransmissionsCache) ReadAnswer() (Answer, error) { - c.ansLock.RLock() - defer c.ansLock.RUnlock() - - // check if stale timeout - var err error - if time.Since(c.ansTime) > c.cfg.OCR2CacheTTL() { - err = errors.New("error in ReadAnswer: stale answer data, polling is likely experiencing errors") - } - return c.answer, err -} - -func (c *TransmissionsCache) fetchLatestTransmission(ctx context.Context) error { - c.lggr.Debugf("fetch latest transmission for account: %s", c.TransmissionsID) - answer, _, err := GetLatestTransmission(ctx, c.reader, c.TransmissionsID, c.cfg.Commitment()) - if err != nil { - return err + name := "ocr2_median_transmissions" + getter := func(ctx context.Context) (Answer, uint64, error) { + return GetLatestTransmission(ctx, reader, transmissionsID, cfg.Commitment()) } - c.lggr.Debugf("latest transmission fetched for account: %s, result: %v", c.TransmissionsID, answer) - - timestamp := time.Now() - monitor.SetCacheTimestamp(timestamp, "ocr2_median_transmissions", c.chainID, c.TransmissionsID.String()) - // acquire lock and write to state - c.ansLock.Lock() - defer c.ansLock.Unlock() - c.answer = answer - c.ansTime = timestamp - return nil + return &TransmissionsCache{client.NewCache(name, transmissionsID, chainID, cfg, getter, logger.With(lggr, "cache", name))} } func GetLatestTransmission(ctx context.Context, reader client.AccountReader, account solana.PublicKey, commitment rpc.CommitmentType) (Answer, uint64, error) { diff --git a/pkg/solana/transmitter.go b/pkg/solana/transmitter.go index 9a26f82d4..e524470d5 100644 --- a/pkg/solana/transmitter.go +++ b/pkg/solana/transmitter.go @@ -97,7 +97,7 @@ func (c *Transmitter) LatestConfigDigestAndEpoch( epoch uint32, err error, ) { - state, err := c.stateCache.ReadState() + state, err := c.stateCache.Read() return state.Config.LatestConfigDigest, state.Config.Epoch, err }