From 76c85a9352a7ac2d295509b487d5975e3f0c1288 Mon Sep 17 00:00:00 2001 From: amit-momin Date: Tue, 26 Nov 2024 18:32:27 -0600 Subject: [PATCH] Added ChainWriter unit tests for GetFeeComponents and GetTransactionStatus --- pkg/solana/chainwriter/chain_writer.go | 6 +- pkg/solana/chainwriter/chain_writer_test.go | 334 ++++++++++++++++++++ pkg/solana/chainwriter/lookups_test.go | 3 +- pkg/solana/txm/txm.go | 10 +- 4 files changed, 347 insertions(+), 6 deletions(-) create mode 100644 pkg/solana/chainwriter/chain_writer_test.go diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 608f3c610..936a0f6d0 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -21,7 +21,7 @@ import ( type SolanaChainWriterService struct { reader client.Reader - txm txm.Txm + txm txm.TxManager ge fees.Estimator config ChainWriterConfig codecs map[string]types.Codec @@ -46,7 +46,7 @@ type MethodConfig struct { DebugIDLocation string } -func NewSolanaChainWriterService(reader client.Reader, txm txm.Txm, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { +func NewSolanaChainWriterService(reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { codecs, err := parseIDLCodecs(config) if err != nil { return nil, fmt.Errorf("failed to parse IDL codecs: %w", err) @@ -275,7 +275,7 @@ var ( // GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM. func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { - return types.Unknown, nil + return s.txm.GetTransactionStatus(ctx, transactionID) } // GetFeeComponents retrieves the associated gas costs for executing a transaction. diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go new file mode 100644 index 000000000..cfc82c2cf --- /dev/null +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -0,0 +1,334 @@ +package chainwriter_test + +import ( + "context" + "crypto/rand" + "math/big" + "sync" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/programs/system" + "github.com/gagliardetto/solana-go/rpc" + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" + feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" +) + +func TestChainWriter_GetAddresses(t *testing.T) {} + +func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {} + +func TestChainWriter_SubmitTransaction(t *testing.T) {} + +func TestChainWriter_GetTransactionStatus(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + lggr := logger.Test(t) + cfg := config.NewDefault() + // Retain transactions after finality or error to maintain their status in memory + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(5 * time.Second) + // Disable bumping to avoid issues with send tx mocking + cfg.Chain.FeeBumpPeriod = relayconfig.MustNewDuration(0 * time.Second) + rw := clientmocks.NewReaderWriter(t) + rw.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe() + rw.On("SlotHeight", mock.Anything).Return(uint64(0), nil).Maybe() + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil }) + ge := feemocks.NewEstimator(t) + // mock solana keystore + keystore := keyMocks.NewSimpleKeystore(t) + keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe() + + // initialize and start TXM + txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + computeUnitLimitDefault := fees.ComputeUnitLimit(cfg.ComputeUnitLimitDefault()) + + // mock signature statuses calls + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + rw.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( + func(_ context.Context, sigs []solana.Signature) (out []*rpc.SignatureStatusesResult) { + for i := range sigs { + get, exists := statuses[sigs[i]] + if !exists { + out = append(out, nil) + continue + } + out = append(out, get()) + } + return out + }, nil, + ) + + t.Run("returns unknown with error if ID not found", func(t *testing.T) { + status, err := cw.GetTransactionStatus(ctx, uuid.NewString()) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) + + t.Run("returns pending when transaction is broadcasted", func(t *testing.T) { + tx, signed := getTx(t, 1, keystore) + signedTx := signed(0, true, computeUnitLimitDefault) + for _, ins := range signedTx.Message.Instructions { + if cuprice, err := fees.ParseComputeUnitPrice(ins.Data); err == nil { + t.Log("compute unit price", cuprice) + } + } + sig := randomSignature(t) + rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + + // mock transaction in broadcasted state + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + defer func() { count++ }() + if count == 0 { + wg.Done() + } + return nil + } + + txID := uuid.NewString() + err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID) + require.NoError(t, err) + // wait till transaction is broadcasted + wg.Wait() + // wait for next confirm cycle to ensure transaction had enough time to update in storage + time.Sleep(cfg.ConfirmPollPeriod()) + + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Pending, status) + }) + + t.Run("returns unconfirmed when transaction is processed", func(t *testing.T) { + tx, signed := getTx(t, 2, keystore) + sig := randomSignature(t) + rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + + // mock transaction in processed state + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + defer func() { count++ }() + if count == 0 { + wg.Done() + } + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusProcessed} + } + + txID := uuid.NewString() + err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID) + require.NoError(t, err) + // wait till transaction is processed + wg.Wait() + // wait for next confirm cycle to ensure transaction had enough time to update in storage + time.Sleep(cfg.ConfirmPollPeriod()) + + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Unconfirmed, status) + }) + + t.Run("returns unconfirmed when transaction is confirmed", func(t *testing.T) { + tx, signed := getTx(t, 3, keystore) + sig := randomSignature(t) + rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + + // mock transaction in processed state + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + defer func() { count++ }() + if count == 0 { + wg.Done() + } + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusConfirmed} + } + + txID := uuid.NewString() + err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID) + require.NoError(t, err) + // wait till transaction is confirmed + wg.Wait() + // wait for next confirm cycle to ensure transaction had enough time to update in storage + time.Sleep(cfg.ConfirmPollPeriod()) + + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Unconfirmed, status) + }) + + t.Run("returns finalized when transaction is finalized", func(t *testing.T) { + tx, signed := getTx(t, 4, keystore) + sig := randomSignature(t) + rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + + // mock transaction in processed state + var wg sync.WaitGroup + wg.Add(1) + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + defer wg.Done() + return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusFinalized} + } + + txID := uuid.NewString() + err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID) + require.NoError(t, err) + // wait till transaction is finalized + wg.Wait() + // wait for next confirm cycle to ensure transaction had enough time to update in storage + time.Sleep(cfg.ConfirmPollPeriod()) + + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + }) + + t.Run("returns failed when error encountered", func(t *testing.T) { + tx, signed := getTx(t, 5, keystore) + sig := randomSignature(t) + var wg sync.WaitGroup + wg.Add(1) + rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil) + rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) { + wg.Done() + }).Return(&rpc.SimulateTransactionResult{ + Err: "FAIL", + }, nil).Maybe() + + // mock transaction in processed state + statuses[sig] = func() (out *rpc.SignatureStatusesResult) { + return nil + } + + txID := uuid.NewString() + err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID) + require.NoError(t, err) + // wait till transaction is finalized + wg.Wait() + + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + }) +} + +func TestChainWriter_GetFeeComponents(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + cfg := config.NewDefault() + rw := clientmocks.NewReaderWriter(t) + ge := feemocks.NewEstimator(t) + ge.On("BaseComputeUnitPrice").Return(uint64(100)) + cw := setupChainWriter(t, cfg, rw, ge) + t.Run("returns valid compute unit price", func(t *testing.T) { + feeComponents, err := cw.GetFeeComponents(ctx) + require.NoError(t, err) + require.Equal(t, big.NewInt(100), feeComponents.ExecutionFee) + require.Nil(t, feeComponents.DataAvailabilityFee) // always nil for Solana + }) + + t.Run("fails if gas estimator not set", func(t *testing.T) { + cwNoEstimator := setupChainWriter(t, cfg, rw, nil) + _, err := cwNoEstimator.GetFeeComponents(ctx) + require.Error(t, err) + }) +} + +func setupChainWriter(t *testing.T, cfg *config.TOMLConfig, rw client.ReaderWriter, ge fees.Estimator) *chainwriter.SolanaChainWriterService { + ctx := tests.Context(t) + lggr := logger.Test(t) + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil }) + // mock solana keystore + keystore := keyMocks.NewSimpleKeystore(t) + keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe() + // initialize and start TXM + txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + return cw +} + +func randomSignature(t *testing.T) solana.Signature { + // make random signature + sig := make([]byte, 64) + _, err := rand.Read(sig) + require.NoError(t, err) + + return solana.SignatureFromBytes(sig) +} + +// create placeholder transaction and returns func for signed tx with fee +func getTx(t *testing.T, val uint64, keystore txm.SimpleKeystore) (*solana.Transaction, func(fees.ComputeUnitPrice, bool, fees.ComputeUnitLimit) *solana.Transaction) { + pubkey := solana.PublicKey{} + + // create transfer tx + tx, err := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + val, + pubkey, + pubkey, + ).Build(), + }, + solana.Hash{}, + solana.TransactionPayer(pubkey), + ) + require.NoError(t, err) + + base := *tx // tx to send to txm, txm will add fee & sign + + return &base, func(price fees.ComputeUnitPrice, addLimit bool, limit fees.ComputeUnitLimit) *solana.Transaction { + tx := base + // add fee parameters + require.NoError(t, fees.SetComputeUnitPrice(&tx, price)) + if addLimit { + require.NoError(t, fees.SetComputeUnitLimit(&tx, limit)) // default + } + + // sign tx + txMsg, err := tx.Message.MarshalBinary() + require.NoError(t, err) + sigBytes, err := keystore.Sign(tests.Context(t), pubkey.String(), txMsg) + require.NoError(t, err) + var finalSig [64]byte + copy(finalSig[:], sigBytes) + tx.Signatures = append(tx.Signatures, finalSig) + return &tx + } +} diff --git a/pkg/solana/chainwriter/lookups_test.go b/pkg/solana/chainwriter/lookups_test.go index 984f3eb92..36c5019f3 100644 --- a/pkg/solana/chainwriter/lookups_test.go +++ b/pkg/solana/chainwriter/lookups_test.go @@ -11,6 +11,7 @@ import ( "github.com/gagliardetto/solana-go/rpc" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" @@ -297,7 +298,7 @@ func TestLookupTables(t *testing.T) { txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) - cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, *txm, nil, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, txm, nil, chainwriter.ChainWriterConfig{}) t.Run("StaticLookup table resolves properly", func(t *testing.T) { pubKeys := createTestPubKeys(t, 8) diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 342f54dce..8882a014d 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -36,8 +36,6 @@ const ( MaxComputeUnitLimit = 1_400_000 // max compute unit limit a transaction can have ) -var _ services.Service = (*Txm)(nil) - type SimpleKeystore interface { Sign(ctx context.Context, account string, data []byte) (signature []byte, err error) Accounts(ctx context.Context) (accounts []string, err error) @@ -45,6 +43,14 @@ type SimpleKeystore interface { var _ loop.Keystore = (SimpleKeystore)(nil) +type TxManager interface { + services.Service + Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...SetTxConfig) error + GetTransactionStatus(ctx context.Context, transactionID string) (commontypes.TransactionStatus, error) +} + +var _ TxManager = (*Txm)(nil) + // Txm manages transactions for the solana blockchain. // simple implementation with no persistently stored txs type Txm struct {