Skip to content

Commit

Permalink
Updated chain writer tests to use txm mock
Browse files Browse the repository at this point in the history
  • Loading branch information
amit-momin committed Dec 3, 2024
1 parent a84504a commit 9b33790
Showing 1 changed file with 32 additions and 259 deletions.
291 changes: 32 additions & 259 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,27 @@
package chainwriter_test

import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"os"
"reflect"
"sync"
"testing"
"time"

"github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/programs/system"
"github.com/gagliardetto/solana-go/rpc"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-common/pkg/utils"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/fees"
feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/txm"
keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks"
txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks"
)

func TestChainWriter_GetAddresses(t *testing.T) {}
Expand All @@ -43,25 +32,13 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
lggr := logger.Test(t)
cfg := config.NewDefault()
// Retain transactions after finality or error to maintain their status in memory
cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(5 * time.Second)
// Disable bumping to avoid issues with send tx mocking
cfg.Chain.FeeBumpPeriod = relayconfig.MustNewDuration(0 * time.Second)
rw := clientmocks.NewReaderWriter(t)
rw.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe()
rw.On("SlotHeight", mock.Anything).Return(uint64(0), nil).Maybe()
loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil })
ge := feemocks.NewEstimator(t)
// mock solana keystore
keystore := keyMocks.NewSimpleKeystore(t)
keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe()

// initialize and start TXM
txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr)
require.NoError(t, txm.Start(ctx))
t.Cleanup(func() { require.NoError(t, txm.Close()) })
// mock txm
txm := txmMocks.NewTxManager(t)

idlJSON, err := os.ReadFile("../../../contracts/target/idl/write_test.json")
require.NoError(t, err)
Expand Down Expand Up @@ -193,218 +170,79 @@ func TestChainWriter_GetTransactionStatus(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
lggr := logger.Test(t)
cfg := config.NewDefault()
// Retain transactions after finality or error to maintain their status in memory
cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(5 * time.Second)
// Disable bumping to avoid issues with send tx mocking
cfg.Chain.FeeBumpPeriod = relayconfig.MustNewDuration(0 * time.Second)
rw := clientmocks.NewReaderWriter(t)
rw.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe()
rw.On("SlotHeight", mock.Anything).Return(uint64(0), nil).Maybe()
loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil })
ge := feemocks.NewEstimator(t)
// mock solana keystore
keystore := keyMocks.NewSimpleKeystore(t)
keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe()

// initialize and start TXM
txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr)
require.NoError(t, txm.Start(ctx))
t.Cleanup(func() { require.NoError(t, txm.Close()) })
// mock txm
txm := txmMocks.NewTxManager(t)

// initialize chain writer
cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{})
require.NoError(t, err)

computeUnitLimitDefault := fees.ComputeUnitLimit(cfg.ComputeUnitLimitDefault())

// mock signature statuses calls
statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{}
rw.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return(
func(_ context.Context, sigs []solana.Signature) (out []*rpc.SignatureStatusesResult) {
for i := range sigs {
get, exists := statuses[sigs[i]]
if !exists {
out = append(out, nil)
continue
}
out = append(out, get())
}
return out
}, nil,
)

t.Run("returns unknown with error if ID not found", func(t *testing.T) {
status, err := cw.GetTransactionStatus(ctx, uuid.NewString())
txID := uuid.NewString()
txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unknown, errors.New("tx not found")).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.Error(t, err)
require.Equal(t, types.Unknown, status)
})

t.Run("returns pending when transaction is broadcasted", func(t *testing.T) {
tx, signed := getTx(t, 1, keystore)
signedTx := signed(0, true, computeUnitLimitDefault)
for _, ins := range signedTx.Message.Instructions {
if cuprice, err := fees.ParseComputeUnitPrice(ins.Data); err == nil {
t.Log("compute unit price", cuprice)
}
}
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in broadcasted state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return nil
}

t.Run("returns pending when transaction is pending", func(t *testing.T) {
txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is broadcasted
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Pending, nil).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Pending, status)
})

t.Run("returns unconfirmed when transaction is processed", func(t *testing.T) {
tx, signed := getTx(t, 2, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusProcessed}
}

t.Run("returns unconfirmed when transaction is unconfirmed", func(t *testing.T) {
txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is processed
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unconfirmed, nil).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Unconfirmed, status)
})

t.Run("returns unconfirmed when transaction is confirmed", func(t *testing.T) {
tx, signed := getTx(t, 3, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusConfirmed}
}

t.Run("returns finalized when transaction is finalized", func(t *testing.T) {
txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is confirmed
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Finalized, nil).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Unconfirmed, status)
require.Equal(t, types.Finalized, status)
})

t.Run("returns finalized when transaction is finalized", func(t *testing.T) {
tx, signed := getTx(t, 4, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer wg.Done()
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusFinalized}
}

t.Run("returns failed when transaction error classfied as failed", func(t *testing.T) {
txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is finalized
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Failed, nil).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Finalized, status)
require.Equal(t, types.Failed, status)
})

t.Run("returns failed when error encountered", func(t *testing.T) {
tx, signed := getTx(t, 5, keystore)
sig := randomSignature(t)
var wg sync.WaitGroup
wg.Add(1)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) {
wg.Done()
}).Return(&rpc.SimulateTransactionResult{
Err: "FAIL",
}, nil).Maybe()

// mock transaction in processed state
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
return nil
}

t.Run("returns fatal when transaction error classfied as fatal", func(t *testing.T) {
txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is finalized
wg.Wait()

txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Fatal, nil).Once()
status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Failed, status)
require.Equal(t, types.Fatal, status)
})
}

func TestChainWriter_GetFeeComponents(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
cfg := config.NewDefault()
rw := clientmocks.NewReaderWriter(t)
ge := feemocks.NewEstimator(t)
ge.On("BaseComputeUnitPrice").Return(uint64(100))
cw := setupChainWriter(t, cfg, rw, ge)

// mock txm
txm := txmMocks.NewTxManager(t)

cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{})
require.NoError(t, err)

t.Run("returns valid compute unit price", func(t *testing.T) {
feeComponents, err := cw.GetFeeComponents(ctx)
require.NoError(t, err)
Expand All @@ -413,74 +251,9 @@ func TestChainWriter_GetFeeComponents(t *testing.T) {
})

t.Run("fails if gas estimator not set", func(t *testing.T) {
cwNoEstimator := setupChainWriter(t, cfg, rw, nil)
_, err := cwNoEstimator.GetFeeComponents(ctx)
cwNoEstimator, err := chainwriter.NewSolanaChainWriterService(rw, txm, nil, chainwriter.ChainWriterConfig{})
require.NoError(t, err)
_, err = cwNoEstimator.GetFeeComponents(ctx)
require.Error(t, err)
})
}

func setupChainWriter(t *testing.T, cfg *config.TOMLConfig, rw client.ReaderWriter, ge fees.Estimator) *chainwriter.SolanaChainWriterService {
ctx := tests.Context(t)
lggr := logger.Test(t)
loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil })
// mock solana keystore
keystore := keyMocks.NewSimpleKeystore(t)
keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe()
// initialize and start TXM
txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr)
require.NoError(t, txm.Start(ctx))
t.Cleanup(func() { require.NoError(t, txm.Close()) })

cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{})
require.NoError(t, err)
return cw
}

func randomSignature(t *testing.T) solana.Signature {
// make random signature
sig := make([]byte, 64)
_, err := rand.Read(sig)
require.NoError(t, err)

return solana.SignatureFromBytes(sig)
}

// create placeholder transaction and returns func for signed tx with fee
func getTx(t *testing.T, val uint64, keystore txm.SimpleKeystore) (*solana.Transaction, func(fees.ComputeUnitPrice, bool, fees.ComputeUnitLimit) *solana.Transaction) {
pubkey := solana.PublicKey{}

// create transfer tx
tx, err := solana.NewTransaction(
[]solana.Instruction{
system.NewTransferInstruction(
val,
pubkey,
pubkey,
).Build(),
},
solana.Hash{},
solana.TransactionPayer(pubkey),
)
require.NoError(t, err)

base := *tx // tx to send to txm, txm will add fee & sign

return &base, func(price fees.ComputeUnitPrice, addLimit bool, limit fees.ComputeUnitLimit) *solana.Transaction {
tx := base
// add fee parameters
require.NoError(t, fees.SetComputeUnitPrice(&tx, price))
if addLimit {
require.NoError(t, fees.SetComputeUnitLimit(&tx, limit)) // default
}

// sign tx
txMsg, err := tx.Message.MarshalBinary()
require.NoError(t, err)
sigBytes, err := keystore.Sign(tests.Context(t), pubkey.String(), txMsg)
require.NoError(t, err)
var finalSig [64]byte
copy(finalSig[:], sigBytes)
tx.Signatures = append(tx.Signatures, finalSig)
return &tx
}
}

0 comments on commit 9b33790

Please sign in to comment.