From fb0d5475e61f3b1f3fdd7474d2c5de604ee35558 Mon Sep 17 00:00:00 2001 From: zale144 Date: Thu, 18 Apr 2024 12:13:31 +0200 Subject: [PATCH] fix(p2p): handle default mempool check tx error case (#698) --- p2p/validator.go | 19 ++++---- p2p/validator_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 9 deletions(-) create mode 100644 p2p/validator_test.go diff --git a/p2p/validator.go b/p2p/validator.go index ace29d731..7985142be 100644 --- a/p2p/validator.go +++ b/p2p/validator.go @@ -4,12 +4,13 @@ import ( "context" "errors" - "github.com/dymensionxyz/dymint/mempool" - nodemempool "github.com/dymensionxyz/dymint/node/mempool" - "github.com/dymensionxyz/dymint/types" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/libs/pubsub" corep2p "github.com/tendermint/tendermint/p2p" + + "github.com/dymensionxyz/dymint/mempool" + nodemempool "github.com/dymensionxyz/dymint/node/mempool" + "github.com/dymensionxyz/dymint/types" ) // GossipValidator is a callback function type. @@ -43,9 +44,9 @@ func NewValidator(logger types.Logger, pusbsubServer *pubsub.Server) *Validator func (v *Validator) TxValidator(mp mempool.Mempool, mpoolIDS *nodemempool.MempoolIDs) GossipValidator { return func(txMessage *GossipMessage) bool { v.logger.Debug("transaction received", "bytes", len(txMessage.Data)) - checkTxResCh := make(chan *abci.Response, 1) + var res *abci.Response err := mp.CheckTx(txMessage.Data, func(resp *abci.Response) { - checkTxResCh <- resp + res = resp }, mempool.TxInfo{ SenderID: mpoolIDS.GetForPeer(txMessage.From), SenderP2PID: corep2p.ID(txMessage.From), @@ -59,12 +60,12 @@ func (v *Validator) TxValidator(mp mempool.Mempool, mpoolIDS *nodemempool.Mempoo return false case errors.Is(err, mempool.ErrPreCheck{}): return false - default: + case err != nil: + v.logger.Error("check tx", "error", err) + return false } - res := <-checkTxResCh - checkTxResp := res.GetCheckTx() - return checkTxResp.Code == abci.CodeTypeOK + return res.GetCheckTx().Code == abci.CodeTypeOK } } diff --git a/p2p/validator_test.go b/p2p/validator_test.go new file mode 100644 index 000000000..11eae09d2 --- /dev/null +++ b/p2p/validator_test.go @@ -0,0 +1,103 @@ +package p2p_test + +import ( + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/libs/log" + + "github.com/tendermint/tendermint/types" + + "github.com/dymensionxyz/dymint/mempool" + nodemempool "github.com/dymensionxyz/dymint/node/mempool" + "github.com/dymensionxyz/dymint/p2p" +) + +func TestValidator_TxValidator(t *testing.T) { + type args struct { + mp mempool.Mempool + numMsgs int + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "valid: tx already in cache", + args: args{ + mp: &mockMP{err: mempool.ErrTxInCache}, + numMsgs: 3, + }, + want: true, + }, { + name: "valid: mempool is full", + args: args{ + mp: &mockMP{err: mempool.ErrMempoolIsFull{}}, + numMsgs: 3, + }, + want: true, + }, { + name: "invalid: tx too large", + args: args{ + mp: &mockMP{err: mempool.ErrTxTooLarge{}}, + numMsgs: 3, + }, + want: false, + }, { + name: "invalid: pre-check error", + args: args{ + mp: &mockMP{err: mempool.ErrPreCheck{}}, + numMsgs: 3, + }, + want: false, + }, { + name: "valid: no error", + args: args{ + mp: &mockMP{}, + numMsgs: 3, + }, + want: true, + }, { + name: "unknown error", + args: args{ + mp: &mockMP{err: assert.AnError}, + numMsgs: 3, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := log.TestingLogger() + validateTx := p2p.NewValidator(logger, nil).TxValidator(tt.args.mp, nodemempool.NewMempoolIDs()) + valid := validateTx(txMsg) + assert.Equalf(t, tt.want, valid, "validateTx() = %v, want %v", valid, tt.want) + }) + } +} + +type mockMP struct { + mempool.Mempool + err error +} + +func (m *mockMP) CheckTx(_ types.Tx, cb func(*abci.Response), _ mempool.TxInfo) error { + if cb != nil { + code := abci.CodeTypeOK + if m.err != nil { + code = 1 + } + cb(&abci.Response{ + Value: &abci.Response_CheckTx{CheckTx: &abci.ResponseCheckTx{Code: code}}, + }) + } + return m.err +} + +var txMsg = &p2p.GossipMessage{ + Data: []byte("data"), + From: peer.ID("from"), +}