diff --git a/core/chains/evm/logpoller/orm.go b/core/chains/evm/logpoller/orm.go index 3a99a38b3fd..1b18e50e217 100644 --- a/core/chains/evm/logpoller/orm.go +++ b/core/chains/evm/logpoller/orm.go @@ -694,11 +694,12 @@ func (o *DbORM) SelectIndexedLogsWithSigsExcluding(sigA, sigB common.Hash, topic } func (o *DbORM) InsertLogsWithBlock(logs []Log, block LogPollerBlock, qopts ...pg.QOpt) error { - // Optimization, don't open TX when there is only block to be persisted + // Optimization, don't open TX when there is only a block to be persisted if len(logs) == 0 { return o.InsertBlock(block.BlockHash, block.BlockNumber, block.BlockTimestamp, block.FinalizedBlockNumber, qopts...) } + // Block and logs goes with the same TX to ensure atomicity return o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { if err := o.InsertBlock(block.BlockHash, block.BlockNumber, block.BlockTimestamp, block.FinalizedBlockNumber, pg.WithQueryer(tx)); err != nil { return err diff --git a/core/chains/evm/logpoller/orm_test.go b/core/chains/evm/logpoller/orm_test.go index 1221d67c1f8..8736ea33c95 100644 --- a/core/chains/evm/logpoller/orm_test.go +++ b/core/chains/evm/logpoller/orm_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "fmt" + "math" "math/big" "testing" "time" @@ -15,7 +16,10 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/cltest/heavyweight" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" + "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -1300,3 +1304,80 @@ func TestNestedLogPollerBlocksQuery(t *testing.T) { require.NoError(t, err) require.Len(t, logs, 0) } + +func TestInsertLogsWithBlock(t *testing.T) { + chainID := testutils.NewRandomEVMChainID() + event := utils.RandomBytes32() + address := utils.RandomAddress() + + // We need full db here, because we want to test transaction rollbacks. + // Using pgtest.NewSqlxDB(t) will run all tests in TXs which is not desired for this type of test + // (inner tx rollback will rollback outer tx, blocking rest of execution) + _, db := heavyweight.FullTestDBV2(t, "logpoller_tx", nil) + o := logpoller.NewORM(chainID, db, logger.TestLogger(t), pgtest.NewQConfig(true)) + + correctLog := GenLog(chainID, 1, 1, utils.RandomAddress().String(), event[:], address) + invalidLog := GenLog(chainID, -10, -10, utils.RandomAddress().String(), event[:], address) + correctBlock := logpoller.NewLogPollerBlock(utils.RandomBytes32(), 20, time.Now(), 10) + invalidBlock := logpoller.NewLogPollerBlock(utils.RandomBytes32(), -10, time.Now(), -10) + + tests := []struct { + name string + logs []logpoller.Log + block logpoller.LogPollerBlock + shouldRollback bool + }{ + { + name: "properly persist all data", + logs: []logpoller.Log{correctLog}, + block: correctBlock, + shouldRollback: false, + }, + { + name: "rollbacks transaction when block is invalid", + logs: []logpoller.Log{correctLog}, + block: invalidBlock, + shouldRollback: true, + }, + { + name: "rollbacks transaction when log is invalid", + logs: []logpoller.Log{invalidLog}, + block: correctBlock, + shouldRollback: true, + }, + { + name: "rollback when only some logs are invalid", + logs: []logpoller.Log{correctLog, invalidLog}, + block: correctBlock, + shouldRollback: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // clean all logs and blocks between test cases + defer func() { _ = o.DeleteLogsAndBlocksAfter(0) }() + insertError := o.InsertLogsWithBlock(tt.logs, tt.block) + + logs, logsErr := o.SelectLogs(0, math.MaxInt, address, event) + block, blockErr := o.SelectLatestBlock() + + if tt.shouldRollback { + assert.Error(t, insertError) + + assert.NoError(t, logsErr) + assert.Len(t, logs, 0) + + assert.Error(t, blockErr) + } else { + assert.NoError(t, insertError) + + assert.NoError(t, logsErr) + assert.Len(t, logs, len(tt.logs)) + + assert.NoError(t, blockErr) + assert.Equal(t, block.BlockNumber, tt.block.BlockNumber) + } + }) + } +}