From e68ee4cee85e2e20647ed9d3d6a22a7da1f3e3fb Mon Sep 17 00:00:00 2001 From: amit-momin <108959691+amit-momin@users.noreply.github.com> Date: Tue, 26 Nov 2024 09:02:15 -0600 Subject: [PATCH] Store transaction errors caught before broadcast (#936) * Enabled TXM to store error statuses for transactions caught before broadcast * Addressed feedback * Removed id from finished tx metadata to reduce memory footprint * Updated logs --- pkg/solana/txm/pendingtx.go | 188 +++++++++++++++++-------- pkg/solana/txm/pendingtx_test.go | 170 +++++++++++++++++----- pkg/solana/txm/txm.go | 140 ++++++++++++------- pkg/solana/txm/txm_internal_test.go | 210 ++++++++++++++++++++-------- pkg/solana/txm/txm_unit_test.go | 18 +-- pkg/solana/txm/utils.go | 4 + 6 files changed, 518 insertions(+), 212 deletions(-) diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index b2c3c98ed..ecae7243b 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -36,20 +36,28 @@ type PendingTxContext interface { OnConfirmed(sig solana.Signature) (string, error) // OnFinalized marks transaction as Finalized, moves it from the broadcasted or confirmed map to finalized map, removes signatures from signature map to stop confirmation checks OnFinalized(sig solana.Signature, retentionTimeout time.Duration) (string, error) + // OnPrebroadcastError adds transaction that has not yet been broadcasted to the finalized/errored map as errored, matches err type using enum + OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error // OnError marks transaction as errored, matches err type using enum, moves it from the broadcasted or confirmed map to finalized/errored map, removes signatures from signature map to stop confirmation checks - OnError(sig solana.Signature, retentionTimeout time.Duration, errType int) (string, error) + OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) // GetTxState returns the transaction state for the provided ID if it exists GetTxState(id string) (TxState, error) // TrimFinalizedErroredTxs removes transactions that have reached their retention time - TrimFinalizedErroredTxs() + TrimFinalizedErroredTxs() int } +// finishedTx is used to store info required to track transactions to finality or error type pendingTx struct { - tx solana.Transaction - cfg TxConfig - signatures []solana.Signature - id string - createTs time.Time + tx solana.Transaction + cfg TxConfig + signatures []solana.Signature + id string + createTs time.Time + state TxState +} + +// finishedTx is used to store minimal info specifically for finalized or errored transactions for external status checks +type finishedTx struct { retentionTs time.Time state TxState } @@ -60,9 +68,9 @@ type pendingTxContext struct { cancelBy map[string]context.CancelFunc sigToID map[solana.Signature]string - broadcastedTxs map[string]pendingTx // transactions that require retry and bumping i.e broadcasted, processed - confirmedTxs map[string]pendingTx // transactions that require monitoring for re-org - finalizedErroredTxs map[string]pendingTx // finalized and errored transactions held onto for status + broadcastedTxs map[string]pendingTx // transactions that require retry and bumping i.e broadcasted, processed + confirmedTxs map[string]pendingTx // transactions that require monitoring for re-org + finalizedErroredTxs map[string]finishedTx // finalized and errored transactions held onto for status lock sync.RWMutex } @@ -74,7 +82,7 @@ func newPendingTxContext() *pendingTxContext { broadcastedTxs: map[string]pendingTx{}, confirmedTxs: map[string]pendingTx{}, - finalizedErroredTxs: map[string]pendingTx{}, + finalizedErroredTxs: map[string]finishedTx{}, } } @@ -262,7 +270,6 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { if !exists { return id, ErrTransactionNotFound } - tx = c.broadcastedTxs[id] // update tx state to Processed tx.state = Processed // save updated tx back to the broadcasted map @@ -298,7 +305,8 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { if !sigExists { return id, ErrSigDoesNotExist } - if _, exists := c.broadcastedTxs[id]; !exists { + tx, exists := c.broadcastedTxs[id] + if !exists { return id, ErrTransactionNotFound } // call cancel func + remove from map to stop the retry/bumping cycle for this transaction @@ -306,7 +314,6 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { cancel() // cancel context delete(c.cancelBy, id) } - tx := c.broadcastedTxs[id] // update tx state to Confirmed tx.state = Confirmed // move tx to confirmed map @@ -371,17 +378,58 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti if retentionTimeout == 0 { return id, nil } - // set the timestamp till which the tx should be retained in storage - tx.retentionTs = time.Now().Add(retentionTimeout) - // update tx state to Finalized - tx.state = Finalized + finalizedTx := finishedTx{ + state: Finalized, + retentionTs: time.Now().Add(retentionTimeout), + } // move transaction from confirmed to finalized map - c.finalizedErroredTxs[id] = tx + c.finalizedErroredTxs[id] = finalizedTx return id, nil }) } -func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, _ int) (string, error) { +func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, _ TxErrType) error { + // nothing to do if retention timeout is 0 since transaction is not stored yet. + if retentionTimeout == 0 { + return nil + } + err := c.withReadLock(func() error { + if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { + return ErrAlreadyInExpectedState + } + _, broadcastedExists := c.broadcastedTxs[id] + _, confirmedExists := c.confirmedTxs[id] + if broadcastedExists || confirmedExists { + return ErrIDAlreadyExists + } + return nil + }) + if err != nil { + return err + } + + // upgrade to write lock if id does not exist in other maps and is not in expected state already + _, err = c.withWriteLock(func() (string, error) { + if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { + return "", ErrAlreadyInExpectedState + } + _, broadcastedExists := c.broadcastedTxs[id] + _, confirmedExists := c.confirmedTxs[id] + if broadcastedExists || confirmedExists { + return "", ErrIDAlreadyExists + } + erroredTx := finishedTx{ + state: txState, + retentionTs: time.Now().Add(retentionTimeout), + } + // add transaction to error map + c.finalizedErroredTxs[id] = erroredTx + return id, nil + }) + return err +} + +func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, _ TxErrType) (string, error) { err := c.withReadLock(func() error { id, sigExists := c.sigToID[sig] if !sigExists { @@ -432,17 +480,16 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D for _, s := range tx.signatures { delete(c.sigToID, s) } - // if retention duration is set to 0, delete transaction from storage - // otherwise, move to finalized map + // if retention duration is set to 0, skip adding transaction to the errored map if retentionTimeout == 0 { return id, nil } - // set the timestamp till which the tx should be retained in storage - tx.retentionTs = time.Now().Add(retentionTimeout) - // update tx state to Errored - tx.state = Errored + erroredTx := finishedTx{ + state: txState, + retentionTs: time.Now().Add(retentionTimeout), + } // move transaction from broadcasted to error map - c.finalizedErroredTxs[id] = tx + c.finalizedErroredTxs[id] = erroredTx return id, nil }) } @@ -463,18 +510,31 @@ func (c *pendingTxContext) GetTxState(id string) (TxState, error) { } // TrimFinalizedErroredTxs deletes transactions from the finalized/errored map and the allTxs map after the retention period has passed -func (c *pendingTxContext) TrimFinalizedErroredTxs() { - c.lock.Lock() - defer c.lock.Unlock() - expiredIDs := make([]string, 0, len(c.finalizedErroredTxs)) - for id, tx := range c.finalizedErroredTxs { - if time.Now().After(tx.retentionTs) { - expiredIDs = append(expiredIDs, id) +func (c *pendingTxContext) TrimFinalizedErroredTxs() int { + var expiredIDs []string + err := c.withReadLock(func() error { + expiredIDs = make([]string, 0, len(c.finalizedErroredTxs)) + for id, tx := range c.finalizedErroredTxs { + if time.Now().After(tx.retentionTs) { + expiredIDs = append(expiredIDs, id) + } } + return nil + }) + if err != nil { + return 0 } - for _, id := range expiredIDs { - delete(c.finalizedErroredTxs, id) + + _, err = c.withWriteLock(func() (string, error) { + for _, id := range expiredIDs { + delete(c.finalizedErroredTxs, id) + } + return "", nil + }) + if err != nil { + return 0 } + return len(expiredIDs) } func (c *pendingTxContext) withReadLock(fn func() error) error { @@ -496,8 +556,11 @@ type pendingTxContextWithProm struct { chainID string } +type TxErrType int + const ( - TxFailRevert = iota + NoFailure TxErrType = iota + TxFailRevert TxFailReject TxFailDrop TxFailSimRevert @@ -554,44 +617,45 @@ func (c *pendingTxContextWithProm) OnFinalized(sig solana.Signature, retentionTi return id, err } -func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, errType int) (string, error) { - // special RPC rejects transaction (signature will not be valid) - if errType == TxFailReject { - promSolTxmRejectTxs.WithLabelValues(c.chainID).Add(1) - promSolTxmErrorTxs.WithLabelValues(c.chainID).Add(1) - return "", nil +func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) { + id, err := c.pendingTx.OnError(sig, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed + if err == nil { + incrementErrorMetrics(errType, c.chainID) } + return id, err +} - id, err := c.pendingTx.OnError(sig, retentionTimeout, errType) // err indicates transaction not found so may already be removed +func (c *pendingTxContextWithProm) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error { + err := c.pendingTx.OnPrebroadcastError(id, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed if err == nil { - switch errType { - case TxFailRevert: - promSolTxmRevertTxs.WithLabelValues(c.chainID).Add(1) - promSolTxmErrorTxs.WithLabelValues(c.chainID).Add(1) - case TxFailDrop: - promSolTxmDropTxs.WithLabelValues(c.chainID).Add(1) - promSolTxmErrorTxs.WithLabelValues(c.chainID).Add(1) - } + incrementErrorMetrics(errType, c.chainID) } + return err +} - // Increment simulation error metrics even if no tx found for sig - // Simulation could have occurred before initial broadcast so tx was never stored +func incrementErrorMetrics(errType TxErrType, chainID string) { switch errType { + case NoFailure: + // Return early if no failure identified + return + case TxFailReject: + promSolTxmRejectTxs.WithLabelValues(chainID).Inc() + case TxFailRevert: + promSolTxmRevertTxs.WithLabelValues(chainID).Inc() + case TxFailDrop: + promSolTxmDropTxs.WithLabelValues(chainID).Inc() case TxFailSimRevert: - promSolTxmSimRevertTxs.WithLabelValues(c.chainID).Add(1) - promSolTxmErrorTxs.WithLabelValues(c.chainID).Add(1) + promSolTxmSimRevertTxs.WithLabelValues(chainID).Inc() case TxFailSimOther: - promSolTxmSimOtherTxs.WithLabelValues(c.chainID).Add(1) - promSolTxmErrorTxs.WithLabelValues(c.chainID).Add(1) + promSolTxmSimOtherTxs.WithLabelValues(chainID).Inc() } - - return id, err + promSolTxmErrorTxs.WithLabelValues(chainID).Inc() } func (c *pendingTxContextWithProm) GetTxState(id string) (TxState, error) { return c.pendingTx.GetTxState(id) } -func (c *pendingTxContextWithProm) TrimFinalizedErroredTxs() { - c.pendingTx.TrimFinalizedErroredTxs() +func (c *pendingTxContextWithProm) TrimFinalizedErroredTxs() int { + return c.pendingTx.TrimFinalizedErroredTxs() } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index b1212ca21..e7b7fc51e 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -93,11 +93,11 @@ func TestPendingTxContext_new(t *testing.T) { require.Equal(t, Broadcasted, tx.state) // Check it does not exist in confirmed map - tx, exists = txs.confirmedTxs[msg.id] + _, exists = txs.confirmedTxs[msg.id] require.False(t, exists) // Check it does not exist in finalized map - tx, exists = txs.finalizedErroredTxs[msg.id] + _, exists = txs.finalizedErroredTxs[msg.id] require.False(t, exists) } @@ -134,11 +134,11 @@ func TestPendingTxContext_add_signature(t *testing.T) { require.Equal(t, sig2, tx.signatures[1]) // Check confirmed map - tx, exists = txs.confirmedTxs[msg.id] + _, exists = txs.confirmedTxs[msg.id] require.False(t, exists) // Check finalized map - tx, exists = txs.finalizedErroredTxs[msg.id] + _, exists = txs.finalizedErroredTxs[msg.id] require.False(t, exists) }) @@ -225,11 +225,11 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, Processed, tx.state) // Check it does not exist in confirmed map - tx, exists = txs.confirmedTxs[msg.id] + _, exists = txs.confirmedTxs[msg.id] require.False(t, exists) // Check it does not exist in finalized map - tx, exists = txs.finalizedErroredTxs[msg.id] + _, exists = txs.finalizedErroredTxs[msg.id] require.False(t, exists) }) @@ -293,7 +293,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, 0) + id, err := txs.OnError(sig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -364,7 +364,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, Confirmed, tx.state) // Check it does not exist in finalized map - tx, exists = txs.finalizedErroredTxs[msg.id] + _, exists = txs.finalizedErroredTxs[msg.id] require.False(t, exists) }) @@ -405,7 +405,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, 0) + id, err := txs.OnError(sig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -473,9 +473,6 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Check it exists in finalized map tx, exists := txs.finalizedErroredTxs[msg.id] require.True(t, exists) - require.Len(t, tx.signatures, 2) - require.Equal(t, sig1, tx.signatures[0]) - require.Equal(t, sig2, tx.signatures[1]) // Check status is Finalized require.Equal(t, Finalized, tx.state) @@ -526,9 +523,6 @@ func TestPendingTxContext_on_finalized(t *testing.T) { // Check it exists in finalized map tx, exists := txs.finalizedErroredTxs[msg.id] require.True(t, exists) - require.Len(t, tx.signatures, 2) - require.Equal(t, sig1, tx.signatures[0]) - require.Equal(t, sig2, tx.signatures[1]) // Check status is Finalized require.Equal(t, Finalized, tx.state) @@ -589,7 +583,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, 0) + id, err := txs.OnError(sig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -614,7 +608,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, 0) + id, err := txs.OnError(sig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -629,8 +623,6 @@ func TestPendingTxContext_on_error(t *testing.T) { // Check it exists in errored map tx, exists := txs.finalizedErroredTxs[msg.id] require.True(t, exists) - require.Len(t, tx.signatures, 1) - require.Equal(t, sig, tx.signatures[0]) // Check status is Finalized require.Equal(t, Errored, tx.state) @@ -654,7 +646,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, retentionTimeout, 0) + id, err = txs.OnError(sig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -669,8 +661,6 @@ func TestPendingTxContext_on_error(t *testing.T) { // Check it exists in errored map tx, exists := txs.finalizedErroredTxs[msg.id] require.True(t, exists) - require.Len(t, tx.signatures, 1) - require.Equal(t, sig, tx.signatures[0]) // Check status is Finalized require.Equal(t, Errored, tx.state) @@ -680,6 +670,35 @@ func TestPendingTxContext_on_error(t *testing.T) { require.False(t, exists) }) + t.Run("successfully transition transaction from broadcasted/processed to fatally errored state", func(t *testing.T) { + sig := randomSignature(t) + + // Create new transaction + msg := pendingTx{id: uuid.NewString()} + err := txs.New(msg, sig, cancel) + require.NoError(t, err) + + // Transition to fatally errored state + id, err := txs.OnError(sig, retentionTimeout, FatallyErrored, 0) + require.NoError(t, err) + require.Equal(t, msg.id, id) + + // Check it does not exist in broadcasted map + _, exists := txs.broadcastedTxs[msg.id] + require.False(t, exists) + + // Check it exists in errored map + tx, exists := txs.finalizedErroredTxs[msg.id] + require.True(t, exists) + + // Check status is Errored + require.Equal(t, FatallyErrored, tx.state) + + // Check sigs do no exist in signature map + _, exists = txs.sigToID[sig] + require.False(t, exists) + }) + t.Run("successfully delete transaction when errored with 0 retention timeout", func(t *testing.T) { sig := randomSignature(t) @@ -694,7 +713,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, 0*time.Second, 0) + id, err = txs.OnError(sig, 0*time.Second, Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -729,12 +748,76 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition back to confirmed state - id, err = txs.OnError(sig, retentionTimeout, 0) + id, err = txs.OnError(sig, retentionTimeout, Errored, 0) require.Error(t, err) require.Equal(t, "", id) }) } +func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { + t.Parallel() + _, cancel := context.WithCancel(tests.Context(t)) + txs := newPendingTxContext() + retentionTimeout := 5 * time.Second + + t.Run("successfully adds transaction with errored state", func(t *testing.T) { + // Create new transaction + msg := pendingTx{id: uuid.NewString()} + // Transition to errored state + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, Errored, 0) + require.NoError(t, err) + + // Check it exists in errored map + tx, exists := txs.finalizedErroredTxs[msg.id] + require.True(t, exists) + + // Check status is Errored + require.Equal(t, Errored, tx.state) + }) + + t.Run("successfully adds transaction with fatally errored state", func(t *testing.T) { + // Create new transaction + msg := pendingTx{id: uuid.NewString()} + + // Transition to fatally errored state + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + require.NoError(t, err) + + // Check it exists in errored map + tx, exists := txs.finalizedErroredTxs[msg.id] + require.True(t, exists) + + // Check status is Errored + require.Equal(t, FatallyErrored, tx.state) + }) + + t.Run("fails to add transaction to errored map if id exists in another map already", func(t *testing.T) { + sig := randomSignature(t) + + // Create new transaction + msg := pendingTx{id: uuid.NewString()} + // Add transaction to broadcasted map + err := txs.New(msg, sig, cancel) + require.NoError(t, err) + + // Transition to errored state + err = txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + require.ErrorIs(t, err, ErrIDAlreadyExists) + }) + + t.Run("predefined error if transaction already in errored state", func(t *testing.T) { + txID := uuid.NewString() + + // Transition to errored state + err := txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + require.NoError(t, err) + + // Transition back to errored state + err = txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + require.ErrorIs(t, err, ErrAlreadyInExpectedState) + }) +} + func TestPendingTxContext_remove(t *testing.T) { t.Parallel() _, cancel := context.WithCancel(tests.Context(t)) @@ -784,7 +867,7 @@ func TestPendingTxContext_remove(t *testing.T) { erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) @@ -841,32 +924,34 @@ func TestPendingTxContext_trim_finalized_errored_txs(t *testing.T) { txs := newPendingTxContext() // Create new finalized transaction with retention ts in the past and add to map - finalizedMsg1 := pendingTx{id: uuid.NewString(), retentionTs: time.Now().Add(-2 * time.Second)} - txs.finalizedErroredTxs[finalizedMsg1.id] = finalizedMsg1 + finalizedMsg1 := finishedTx{retentionTs: time.Now().Add(-2 * time.Second)} + finalizedMsg1ID := uuid.NewString() + txs.finalizedErroredTxs[finalizedMsg1ID] = finalizedMsg1 // Create new finalized transaction with retention ts in the future and add to map - finalizedMsg2 := pendingTx{id: uuid.NewString(), retentionTs: time.Now().Add(1 * time.Second)} - txs.finalizedErroredTxs[finalizedMsg2.id] = finalizedMsg2 + finalizedMsg2 := finishedTx{retentionTs: time.Now().Add(1 * time.Second)} + finalizedMsg2ID := uuid.NewString() + txs.finalizedErroredTxs[finalizedMsg2ID] = finalizedMsg2 // Create new finalized transaction with retention ts in the past and add to map - erroredMsg := pendingTx{id: uuid.NewString(), retentionTs: time.Now().Add(-2 * time.Second)} - txs.finalizedErroredTxs[erroredMsg.id] = erroredMsg + erroredMsg := finishedTx{retentionTs: time.Now().Add(-2 * time.Second)} + erroredMsgID := uuid.NewString() + txs.finalizedErroredTxs[erroredMsgID] = erroredMsg // Delete finalized/errored transactions that have passed the retention period txs.TrimFinalizedErroredTxs() // Check finalized message past retention is deleted - _, exists := txs.finalizedErroredTxs[finalizedMsg1.id] + _, exists := txs.finalizedErroredTxs[finalizedMsg1ID] require.False(t, exists) // Check errored message past retention is deleted - _, exists = txs.finalizedErroredTxs[erroredMsg.id] + _, exists = txs.finalizedErroredTxs[erroredMsgID] require.False(t, exists) // Check finalized message within retention period still exists - msg, exists := txs.finalizedErroredTxs[finalizedMsg2.id] + _, exists = txs.finalizedErroredTxs[finalizedMsg2ID] require.True(t, exists) - require.Equal(t, finalizedMsg2.id, msg.id) } func TestPendingTxContext_expired(t *testing.T) { @@ -970,6 +1055,7 @@ func TestGetTxState(t *testing.T) { confirmedSig := randomSignature(t) finalizedSig := randomSignature(t) erroredSig := randomSignature(t) + fatallyErroredSig := randomSignature(t) // Create new broadcasted transaction with extra sig broadcastedMsg := pendingTx{id: uuid.NewString()} @@ -1017,7 +1103,7 @@ func TestGetTxState(t *testing.T) { erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) // Check Errored state is returned @@ -1025,6 +1111,18 @@ func TestGetTxState(t *testing.T) { require.NoError(t, err) require.Equal(t, Errored, state) + // Create new fatally errored transaction + fatallyErroredMsg := pendingTx{id: uuid.NewString()} + err = txs.New(fatallyErroredMsg, fatallyErroredSig, cancel) + require.NoError(t, err) + id, err = txs.OnError(fatallyErroredSig, retentionTimeout, FatallyErrored, 0) + require.NoError(t, err) + require.Equal(t, fatallyErroredMsg.id, id) + // Check Errored state is returned + state, err = txs.GetTxState(fatallyErroredMsg.id) + require.NoError(t, err) + require.Equal(t, FatallyErrored, state) + // Check NotFound state is returned if unknown id provided state, err = txs.GetTxState("unknown id") require.Error(t, err) diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 13b7fcfdc..342f54dce 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -239,9 +239,9 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran // send initial tx (do not retry and exit early if fails) sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { - cancel() // cancel context when exiting early - txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), TxFailReject) //nolint // no need to check error since only incrementing metric here - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", initSendErr) + cancel() // cancel context when exiting early + stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) + return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } // store tx signature + cancel function @@ -417,12 +417,12 @@ func (txm *Txm) confirm() { ) // check confirm timeout exceeded - if txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), TxFailDrop) + if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { - txm.lggr.Infow("failed to find transaction within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) + txm.lggr.Debugw("failed to find transaction within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) } } continue @@ -430,11 +430,15 @@ func (txm *Txm) confirm() { // if signature has an error, end polling if res[i].Err != nil { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), TxFailRevert) - if err != nil { - txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "error", err) - } else { - txm.lggr.Debugw("tx state: failed", "id", id, "signature", s[i], "error", res[i].Err, "status", res[i].ConfirmationStatus) + // Process error to determine the corresponding state and type. + // Skip marking as errored if error considered to not be a failure. + if txState, errType := txm.processError(s[i], res[i].Err, false); errType != NoFailure { + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txState, errType) + if err != nil { + txm.lggr.Infow(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "signature", s[i], "error", err) + } else { + txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", s[i], "error", res[i].Err, "status", res[i].ConfirmationStatus) + } } continue } @@ -450,7 +454,7 @@ func (txm *Txm) confirm() { } // check confirm timeout exceeded if TxConfirmTimeout set if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), TxFailDrop) + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { @@ -536,8 +540,18 @@ func (txm *Txm) simulate() { } // Transaction has to have a signature if simulation succeeded but added check for belt and braces approach - if len(msg.signatures) > 0 { - txm.processSimulationError(msg.id, msg.signatures[0], res) + if len(msg.signatures) == 0 { + continue + } + // Process error to determine the corresponding state and type. + // Certain errors can be considered not to be failures during simulation to allow the process to continue + if txState, errType := txm.processError(msg.signatures[0], res.Err, true); errType != NoFailure { + id, err := txm.txs.OnError(msg.signatures[0], txm.cfg.TxRetentionTimeout(), txState, errType) + if err != nil { + txm.lggr.Errorw(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "err", err) + } else { + txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", msg.signatures[0], "error", res.Err) + } } } } @@ -556,7 +570,10 @@ func (txm *Txm) reap() { case <-ctx.Done(): return case <-tick: - txm.txs.TrimFinalizedErroredTxs() + reapCount := txm.txs.TrimFinalizedErroredTxs() + if reapCount > 0 { + txm.lggr.Debugf("Reaped %d finalized or errored transactions", reapCount) + } } tick = time.After(utils.WithJitter(TxReapInterval)) } @@ -591,8 +608,16 @@ func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Tran v(&cfg) } + // Use transaction ID provided by caller if set + id := uuid.New().String() + if txID != nil && *txID != "" { + id = *txID + } + + // Perform compute unit limit estimation after storing transaction + // If error found during simulation, transaction should be in storage to mark accordingly if cfg.EstimateComputeUnitLimit { - computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx) + computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx, id) if err != nil { return fmt.Errorf("transaction failed simulation: %w", err) } @@ -602,11 +627,6 @@ func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Tran } } - // Use transaction ID provided by caller if set - id := uuid.New().String() - if txID != nil && *txID != "" { - id = *txID - } msg := pendingTx{ tx: *tx, cfg: cfg, @@ -638,6 +658,8 @@ func (txm *Txm) GetTransactionStatus(ctx context.Context, transactionID string) return commontypes.Finalized, nil case Errored: return commontypes.Failed, nil + case FatallyErrored: + return commontypes.Fatal, nil default: return commontypes.Unknown, fmt.Errorf("found unknown transaction state: %s", state.String()) } @@ -645,7 +667,7 @@ func (txm *Txm) GetTransactionStatus(ctx context.Context, transactionID string) // EstimateComputeUnitLimit estimates the compute unit limit needed for a transaction. // It simulates the provided transaction to determine the used compute and applies a buffer to it. -func (txm *Txm) EstimateComputeUnitLimit(ctx context.Context, tx *solanaGo.Transaction) (uint32, error) { +func (txm *Txm) EstimateComputeUnitLimit(ctx context.Context, tx *solanaGo.Transaction, id string) (uint32, error) { txCopy := *tx // Set max compute unit limit when simulating a transaction to avoid getting an error for exceeding the default 200k compute unit limit @@ -678,7 +700,14 @@ func (txm *Txm) EstimateComputeUnitLimit(ctx context.Context, tx *solanaGo.Trans if len(txCopy.Signatures) > 0 { sig = txCopy.Signatures[0] } - txm.processSimulationError("", sig, res) + // Process error to determine the corresponding state and type. + // Certain errors can be considered not to be failures during simulation to allow the process to continue + if txState, errType := txm.processError(sig, res.Err, true); errType != NoFailure { + err := txm.txs.OnPrebroadcastError(id, txm.cfg.TxRetentionTimeout(), txState, errType) + if err != nil { + return 0, fmt.Errorf("failed to process error %v for tx ID %s: %w", res.Err, id, err) + } + } return 0, fmt.Errorf("simulated tx returned error: %v", res.Err) } @@ -689,14 +718,12 @@ func (txm *Txm) EstimateComputeUnitLimit(ctx context.Context, tx *solanaGo.Trans } unitsConsumed := *res.UnitsConsumed - // Add buffer to the used compute estimate - unitsConsumed = bigmath.AddPercentage(new(big.Int).SetUint64(unitsConsumed), EstimateComputeUnitLimitBuffer).Uint64() + computeUnitLimit := bigmath.AddPercentage(new(big.Int).SetUint64(unitsConsumed), EstimateComputeUnitLimitBuffer).Uint64() + // Ensure computeUnitLimit does not exceed the max compute unit limit for a transaction after adding buffer + computeUnitLimit = mathutil.Min(computeUnitLimit, MaxComputeUnitLimit) - // Ensure unitsConsumed does not exceed the max compute unit limit for a transaction after adding buffer - unitsConsumed = mathutil.Min(unitsConsumed, MaxComputeUnitLimit) - - return uint32(unitsConsumed), nil //nolint // unitsConsumed can only be a maximum of 1.4M + return uint32(computeUnitLimit), nil //nolint // computeUnitLimit can only be a maximum of 1.4M } // simulateTx simulates transactions using the SimulateTx client method @@ -718,41 +745,58 @@ func (txm *Txm) simulateTx(ctx context.Context, tx *solanaGo.Transaction) (res * return } -// processSimulationError parses and handles relevant errors found in simulation results -func (txm *Txm) processSimulationError(id string, sig solanaGo.Signature, res *rpc.SimulateTransactionResult) { - if res.Err != nil { +// processError parses and handles relevant errors found in simulation results +func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulation bool) (txState TxState, errType TxErrType) { + if resErr != nil { // handle various errors // https://github.com/solana-labs/solana/blob/master/sdk/src/transaction/error.rs - errStr := fmt.Sprintf("%v", res.Err) // convert to string to handle various interfaces + errStr := fmt.Sprintf("%v", resErr) // convert to string to handle various interfaces + txm.lggr.Info(errStr) logValues := []interface{}{ - "id", id, "signature", sig, - "result", res, + "error", resErr, + } + // return TxFailRevert on any error if when processing error during confirmation + errType := TxFailRevert + // return TxFailSimRevert on any known error when processing simulation error + if simulation { + errType = TxFailSimRevert } switch { // blockhash not found when simulating, occurs when network bank has not seen the given blockhash or tx is too old // let confirmation process clean up case strings.Contains(errStr, "BlockhashNotFound"): - txm.lggr.Debugw("simulate: BlockhashNotFound", logValues...) - // transaction will encounter execution error/revert, mark as reverted to remove from confirmation + retry - case strings.Contains(errStr, "InstructionError"): - _, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), TxFailSimRevert) // cancel retry - if err != nil { - logValues = append(logValues, "stateTransitionErr", err) + txm.lggr.Debugw("BlockhashNotFound", logValues...) + // return no failure for this error when simulating to allow later send/retry code to assign a proper blockhash + // in case the one provided by the caller is outdated + if simulation { + return txState, NoFailure } - txm.lggr.Debugw("simulate: InstructionError", logValues...) - // transaction is already processed in the chain, letting txm confirmation handle + return Errored, errType + // transaction will encounter execution error/revert + case strings.Contains(errStr, "InstructionError"): + txm.lggr.Debugw("InstructionError", logValues...) + return Errored, errType + // transaction is already processed in the chain case strings.Contains(errStr, "AlreadyProcessed"): - txm.lggr.Debugw("simulate: AlreadyProcessed", logValues...) + txm.lggr.Debugw("AlreadyProcessed", logValues...) + // return no failure for this error when simulating in case there is a race between broadcast and simulation + // when doing both in parallel + if simulation { + return txState, NoFailure + } + return Errored, errType // unrecognized errors (indicates more concerning failures) default: - _, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), TxFailSimOther) // cancel retry - if err != nil { - logValues = append(logValues, "stateTransitionErr", err) + // if simulating, return TxFailSimOther if error unknown + if simulation { + errType = TxFailSimOther } - txm.lggr.Errorw("simulate: unrecognized error", logValues...) + txm.lggr.Errorw("unrecognized error", logValues...) + return Errored, errType } } + return } func (txm *Txm) InflightTxs() int { diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index f19b26b9a..418bdbec1 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -28,6 +28,7 @@ import ( relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/types" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils" bigmath "github.com/smartcontractkit/chainlink-common/pkg/utils/big_math" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" @@ -136,7 +137,7 @@ func TestTxm(t *testing.T) { loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) txm := NewTxm(id, loader, nil, cfg, mkey, lggr) require.NoError(t, txm.Start(ctx)) - t.Cleanup(func () { require.NoError(t, txm.Close())}) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) // tracking prom metrics prom := soltxmProm{id: id} @@ -776,7 +777,7 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) txm := NewTxm(id, loader, nil, cfg, mkey, lggr) require.NoError(t, txm.Start(ctx)) - t.Cleanup(func () { require.NoError(t, txm.Close())}) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) // tracking prom metrics prom := soltxmProm{id: id} @@ -797,66 +798,153 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { }, nil, ) - // Test tx is not discarded due to confirm timeout and tracked to finalization - tx, signed := getTx(t, 7, mkey) - sig := randomSignature(t) - retry0 := randomSignature(t) - retry1 := randomSignature(t) - var wg sync.WaitGroup - wg.Add(2) - - mc.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) - mc.On("SendTx", mock.Anything, signed(1, true, computeUnitLimitDefault)).Return(retry0, nil).Maybe() - mc.On("SendTx", mock.Anything, signed(2, true, computeUnitLimitDefault)).Return(retry1, nil).Maybe() - mc.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) { - wg.Done() - }).Return(&rpc.SimulateTransactionResult{}, nil).Once() - - // handle signature status calls (initial stays processed, others don't exist) - start := time.Now() - statuses[sig] = func() (out *rpc.SignatureStatusesResult) { - out = &rpc.SignatureStatusesResult{} - // return confirmed status after default confirmation timeout - if time.Since(start) > 1*time.Second && time.Since(start) < 2*time.Second { - out.ConfirmationStatus = rpc.ConfirmationStatusConfirmed + t.Run("happyPath", func(t *testing.T) { + // Test tx is not discarded due to confirm timeout and tracked to finalization + // use unique val across tests to avoid collision during mocking + tx, signed := getTx(t, 1, mkey) + sig := randomSignature(t) + retry0 := randomSignature(t) + retry1 := randomSignature(t) + var wg sync.WaitGroup + wg.Add(2) + + mc.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + mc.On("SendTx", mock.Anything, signed(1, true, computeUnitLimitDefault)).Return(retry0, nil).Maybe() + mc.On("SendTx", mock.Anything, signed(2, true, computeUnitLimitDefault)).Return(retry1, nil).Maybe() + mc.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) { + wg.Done() + }).Return(&rpc.SimulateTransactionResult{}, nil).Once() + + // handle signature status calls (initial stays processed, others don't exist) + start := time.Now() + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + out = &rpc.SignatureStatusesResult{} + // return confirmed status after default confirmation timeout + if time.Since(start) > 1*time.Second && time.Since(start) < 2*time.Second { + out.ConfirmationStatus = rpc.ConfirmationStatusConfirmed + return + } + // return finalized status only after the confirmation timeout + if time.Since(start) >= 2*time.Second { + out.ConfirmationStatus = rpc.ConfirmationStatusFinalized + wg.Done() + return + } + out.ConfirmationStatus = rpc.ConfirmationStatusProcessed return } - // return finalized status only after the confirmation timeout - if time.Since(start) >= 2*time.Second { - out.ConfirmationStatus = rpc.ConfirmationStatusFinalized + + // tx should be able to queue + testTxID := uuid.New().String() + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + wg.Wait() // wait to be picked up and processed + waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout + + // panic if sendTx called after context cancelled + mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // check transaction status which should still be stored + status, err := txm.GetTransactionStatus(ctx, testTxID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + + // Sleep until retention period has passed for transaction and for another reap cycle to run + time.Sleep(10 * time.Second) + + // check if transaction has been purged from memory + status, err = txm.GetTransactionStatus(ctx, testTxID) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) + + t.Run("stores error if initial send fails", func(t *testing.T) { + // Test tx is not discarded due to confirm timeout and tracked to finalization + // use unique val across tests to avoid collision during mocking + tx, signed := getTx(t, 2, mkey) + var wg sync.WaitGroup + wg.Add(1) + + mc.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Run(func(mock.Arguments) { wg.Done() - return + }).Return(nil, errors.New("failed to send")) + + // tx should be able to queue + testTxID := uuid.NewString() + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + wg.Wait() + waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout + + // panic if sendTx called after context cancelled + mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() + + // check prom metric + prom.error++ + prom.reject++ + prom.assertEqual(t) + + // check transaction status which should still be stored + status, err := txm.GetTransactionStatus(ctx, testTxID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + + // Sleep until retention period has passed for transaction and for another reap cycle to run + time.Sleep(15 * time.Second) + + // check if transaction has been purged from memory + status, err = txm.GetTransactionStatus(ctx, testTxID) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) + + t.Run("stores error if confirmation returns error", func(t *testing.T) { + // Test tx is not discarded due to confirm timeout and tracked to finalization + // use unique val across tests to avoid collision during mocking + tx, signed := getTx(t, 3, mkey) + sig := randomSignature(t) + var wg sync.WaitGroup + wg.Add(2) + + mc.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + mc.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) { + wg.Done() + }).Return(&rpc.SimulateTransactionResult{}, nil).Once() + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + defer wg.Done() + return &rpc.SignatureStatusesResult{Err: errors.New("InstructionError")} } - out.ConfirmationStatus = rpc.ConfirmationStatusProcessed - return - } - // tx should be able to queue - testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) - wg.Wait() // wait to be picked up and processed - waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout + // tx should be able to queue + testTxID := uuid.NewString() + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + wg.Wait() // wait till send tx + waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout - // panic if sendTx called after context cancelled - mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() + // panic if sendTx called after context cancelled + mc.On("SendTx", mock.Anything, tx).Panic("SendTx should not be called anymore").Maybe() - // check prom metric - prom.confirmed++ - prom.finalized++ - prom.assertEqual(t) + // check prom metric + prom.error++ + prom.revert++ + prom.assertEqual(t) - // check transaction status which should still be stored - status, err := txm.GetTransactionStatus(ctx, testTxID) - require.NoError(t, err) - require.Equal(t, types.Finalized, status) + // check transaction status which should still be stored + status, err := txm.GetTransactionStatus(ctx, testTxID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) - // Sleep until retention period has passed for transaction and for another reap cycle to run - time.Sleep(10 *time.Second) + // Sleep until retention period has passed for transaction and for another reap cycle to run + time.Sleep(15 * time.Second) - // check if transaction has been purged from memory - status, err = txm.GetTransactionStatus(ctx, testTxID) - require.Error(t, err) - require.Equal(t, types.Unknown, status) + // check if transaction has been purged from memory + status, err = txm.GetTransactionStatus(ctx, testTxID) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) } func TestTxm_compute_unit_limit_estimation(t *testing.T) { @@ -886,7 +974,7 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) txm := NewTxm(id, loader, nil, cfg, mkey, lggr) require.NoError(t, txm.Start(ctx)) - t.Cleanup(func () { require.NoError(t, txm.Close())}) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) // tracking prom metrics prom := soltxmProm{id: id} @@ -909,6 +997,7 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { t.Run("simulation_succeeds", func(t *testing.T) { // Test tx is not discarded due to confirm timeout and tracked to finalization + // use unique val across tests to avoid collision during mocking tx, signed := getTx(t, 1, mkey) // add signature and compute unit limit to tx for simulation (excludes compute unit price) simulateTx := addSigAndLimitToTx(t, mkey, solana.PublicKey{}, *tx, MaxComputeUnitLimit) @@ -972,7 +1061,8 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { t.Run("simulation_fails", func(t *testing.T) { // Test tx is not discarded due to confirm timeout and tracked to finalization - tx, signed := getTx(t, 1, mkey) + // use unique val across tests to avoid collision during mocking + tx, signed := getTx(t, 2, mkey) sig := randomSignature(t) mc.On("SendTx", mock.Anything, signed(0, true, fees.ComputeUnitLimit(0))).Return(sig, nil).Panic("SendTx should never be called").Maybe() @@ -984,16 +1074,22 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { t.Run("simulation_returns_error", func(t *testing.T) { // Test tx is not discarded due to confirm timeout and tracked to finalization - tx, _ := getTx(t, 1, mkey) + // use unique val across tests to avoid collision during mocking + tx, _ := getTx(t, 3, mkey) // add signature and compute unit limit to tx for simulation (excludes compute unit price) simulateTx := addSigAndLimitToTx(t, mkey, solana.PublicKey{}, *tx, MaxComputeUnitLimit) sig := randomSignature(t) mc.On("SendTx", mock.Anything, mock.Anything).Return(sig, nil).Panic("SendTx should never be called").Maybe() // First simulation before broadcast with max compute unit limit - mc.On("SimulateTx", mock.Anything, simulateTx, mock.Anything).Return(&rpc.SimulateTransactionResult{Err: errors.New("tx err")}, nil).Once() + mc.On("SimulateTx", mock.Anything, simulateTx, mock.Anything).Return(&rpc.SimulateTransactionResult{Err: errors.New("InstructionError")}, nil).Once() + txID := uuid.NewString() // tx should NOT be able to queue - assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, nil)) + assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + // tx should be stored in-memory and moved to errored state + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, commontypes.Failed, status) }) } diff --git a/pkg/solana/txm/txm_unit_test.go b/pkg/solana/txm/txm_unit_test.go index 0bac3e478..87803581f 100644 --- a/pkg/solana/txm/txm_unit_test.go +++ b/pkg/solana/txm/txm_unit_test.go @@ -61,8 +61,8 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { client.On("SimulateTx", mock.Anything, mock.IsType(&solana.Transaction{}), mock.IsType(&rpc.SimulateTransactionOpts{})).Run(func(args mock.Arguments) { // Validate max compute unit limit is set in transaction tx := args.Get(1).(*solana.Transaction) - limit, err := fees.ParseComputeUnitLimit(tx.Message.Instructions[len(tx.Message.Instructions)-1].Data) - require.NoError(t, err) + limit, parseErr := fees.ParseComputeUnitLimit(tx.Message.Instructions[len(tx.Message.Instructions)-1].Data) + require.NoError(t, parseErr) require.Equal(t, fees.ComputeUnitLimit(solanatxm.MaxComputeUnitLimit), limit) // Validate signature verification is enabled @@ -73,8 +73,8 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { UnitsConsumed: &usedCompute, }, nil).Once() tx := createTx(t, client, pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) - computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx) - require.NoError(t, err) + computeUnitLimit, estimateErr := txm.EstimateComputeUnitLimit(ctx, tx, "") + require.NoError(t, estimateErr) usedComputeWithBuffer := bigmath.AddPercentage(new(big.Int).SetUint64(usedCompute), solanatxm.EstimateComputeUnitLimitBuffer).Uint64() require.Equal(t, usedComputeWithBuffer, uint64(computeUnitLimit)) }) @@ -88,8 +88,8 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { }, nil).Once() client.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("failed to simulate")).Once() tx := createTx(t, client, pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) - _, err := txm.EstimateComputeUnitLimit(ctx, tx) - require.Error(t, err) + _, estimateErr := txm.EstimateComputeUnitLimit(ctx, tx, "") + require.Error(t, estimateErr) }) t.Run("simulation returns error for tx", func(t *testing.T) { @@ -103,7 +103,7 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { Err: errors.New("InstructionError"), }, nil).Once() tx := createTx(t, client, pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) - _, err := txm.EstimateComputeUnitLimit(ctx, tx) + _, err = txm.EstimateComputeUnitLimit(ctx, tx, "") require.Error(t, err) }) @@ -118,7 +118,7 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { Err: nil, }, nil).Once() tx := createTx(t, client, pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) - computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx) + computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx, "") require.NoError(t, err) require.Equal(t, uint32(0), computeUnitLimit) }) @@ -146,7 +146,7 @@ func TestTxm_EstimateComputeUnitLimit(t *testing.T) { UnitsConsumed: &usedCompute, }, nil).Once() tx := createTx(t, client, pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) - computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx) + computeUnitLimit, err := txm.EstimateComputeUnitLimit(ctx, tx, "") require.NoError(t, err) require.Equal(t, uint32(1_400_000), computeUnitLimit) }) diff --git a/pkg/solana/txm/utils.go b/pkg/solana/txm/utils.go index 6b2253818..fef260e3d 100644 --- a/pkg/solana/txm/utils.go +++ b/pkg/solana/txm/utils.go @@ -19,6 +19,7 @@ type TxState int // < tx processed // < tx confirmed // < tx finalized +// < tx fatallyErrored const ( NotFound TxState = iota Errored @@ -26,6 +27,7 @@ const ( Processed Confirmed Finalized + FatallyErrored ) func (s TxState) String() string { @@ -42,6 +44,8 @@ func (s TxState) String() string { return "Confirmed" case Finalized: return "Finalized" + case FatallyErrored: + return "FatallyErrored" default: return fmt.Sprintf("TxState(%d)", s) }