diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 95edd5cb4..a4076ed1d 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -108,7 +108,7 @@ type verifiedCachedClient struct { client.ReaderWriter } -func (v *verifiedCachedClient) verifyChainID() (bool, error) { +func (v *verifiedCachedClient) verifyChainID(ctx context.Context) (bool, error) { v.chainIDVerifiedLock.RLock() if v.chainIDVerified { v.chainIDVerifiedLock.RUnlock() @@ -121,7 +121,7 @@ func (v *verifiedCachedClient) verifyChainID() (bool, error) { v.chainIDVerifiedLock.Lock() defer v.chainIDVerifiedLock.Unlock() - strID, err := v.ReaderWriter.ChainID(context.Background()) + strID, err := v.ReaderWriter.ChainID(ctx) v.chainID = strID.String() if err != nil { v.chainIDVerified = false @@ -141,7 +141,7 @@ func (v *verifiedCachedClient) verifyChainID() (bool, error) { } func (v *verifiedCachedClient) SendTx(ctx context.Context, tx *solanago.Transaction) (solanago.Signature, error) { - verified, err := v.verifyChainID() + verified, err := v.verifyChainID(ctx) if !verified { return [64]byte{}, err } @@ -150,7 +150,7 @@ func (v *verifiedCachedClient) SendTx(ctx context.Context, tx *solanago.Transact } func (v *verifiedCachedClient) SimulateTx(ctx context.Context, tx *solanago.Transaction, opts *rpc.SimulateTransactionOpts) (*rpc.SimulateTransactionResult, error) { - verified, err := v.verifyChainID() + verified, err := v.verifyChainID(ctx) if !verified { return nil, err } @@ -159,7 +159,7 @@ func (v *verifiedCachedClient) SimulateTx(ctx context.Context, tx *solanago.Tran } func (v *verifiedCachedClient) SignatureStatuses(ctx context.Context, sigs []solanago.Signature) ([]*rpc.SignatureStatusesResult, error) { - verified, err := v.verifyChainID() + verified, err := v.verifyChainID(ctx) if !verified { return nil, err } @@ -167,35 +167,35 @@ func (v *verifiedCachedClient) SignatureStatuses(ctx context.Context, sigs []sol return v.ReaderWriter.SignatureStatuses(ctx, sigs) } -func (v *verifiedCachedClient) Balance(addr solanago.PublicKey) (uint64, error) { - verified, err := v.verifyChainID() +func (v *verifiedCachedClient) Balance(ctx context.Context, addr solanago.PublicKey) (uint64, error) { + verified, err := v.verifyChainID(ctx) if !verified { return 0, err } - return v.ReaderWriter.Balance(addr) + return v.ReaderWriter.Balance(ctx, addr) } -func (v *verifiedCachedClient) SlotHeight() (uint64, error) { - verified, err := v.verifyChainID() +func (v *verifiedCachedClient) SlotHeight(ctx context.Context) (uint64, error) { + verified, err := v.verifyChainID(ctx) if !verified { return 0, err } - return v.ReaderWriter.SlotHeight() + return v.ReaderWriter.SlotHeight(ctx) } -func (v *verifiedCachedClient) LatestBlockhash() (*rpc.GetLatestBlockhashResult, error) { - verified, err := v.verifyChainID() +func (v *verifiedCachedClient) LatestBlockhash(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) { + verified, err := v.verifyChainID(ctx) if !verified { return nil, err } - return v.ReaderWriter.LatestBlockhash() + return v.ReaderWriter.LatestBlockhash(ctx) } func (v *verifiedCachedClient) ChainID(ctx context.Context) (mn.StringID, error) { - verified, err := v.verifyChainID() + verified, err := v.verifyChainID(ctx) if !verified { return "", err } @@ -203,17 +203,17 @@ func (v *verifiedCachedClient) ChainID(ctx context.Context) (mn.StringID, error) return mn.StringID(v.chainID), nil } -func (v *verifiedCachedClient) GetFeeForMessage(msg string) (uint64, error) { - verified, err := v.verifyChainID() +func (v *verifiedCachedClient) GetFeeForMessage(ctx context.Context, msg string) (uint64, error) { + verified, err := v.verifyChainID(ctx) if !verified { return 0, err } - return v.ReaderWriter.GetFeeForMessage(msg) + return v.ReaderWriter.GetFeeForMessage(ctx, msg) } func (v *verifiedCachedClient) GetAccountInfoWithOpts(ctx context.Context, addr solanago.PublicKey, opts *rpc.GetAccountInfoOpts) (*rpc.GetAccountInfoResult, error) { - verified, err := v.verifyChainID() + verified, err := v.verifyChainID(ctx) if !verified { return nil, err } @@ -293,20 +293,18 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L return ch.getClient() } ch.txm = txm.NewTxm(ch.id, tc, cfg, ks, lggr) - bc := func() (monitor.BalanceClient, error) { - return ch.getClient() - } + bc := func() (monitor.BalanceClient, error) { return ch.getClient() } ch.balanceMonitor = monitor.NewBalanceMonitor(ch.id, cfg, lggr, ks, bc) return &ch, nil } -func (c *chain) LatestHead(_ context.Context) (types.Head, error) { +func (c *chain) LatestHead(ctx context.Context) (types.Head, error) { sc, err := c.getClient() if err != nil { return types.Head{}, err } - latestBlock, err := sc.GetLatestBlock() + latestBlock, err := sc.GetLatestBlock(ctx) if err != nil { return types.Head{}, nil } @@ -536,7 +534,7 @@ func (c *chain) sendTx(ctx context.Context, from, to string, amount *big.Int, ba } amountI := amount.Uint64() - blockhash, err := reader.LatestBlockhash() + blockhash, err := reader.LatestBlockhash(ctx) if err != nil { return fmt.Errorf("failed to get latest block hash: %w", err) } @@ -556,13 +554,13 @@ func (c *chain) sendTx(ctx context.Context, from, to string, amount *big.Int, ba } if balanceCheck { - if err = solanaValidateBalance(reader, fromKey, amountI, tx.Message.ToBase64()); err != nil { + if err = solanaValidateBalance(ctx, reader, fromKey, amountI, tx.Message.ToBase64()); err != nil { return fmt.Errorf("failed to validate balance: %w", err) } } chainTxm := c.TxManager() - err = chainTxm.Enqueue("", tx, + err = chainTxm.Enqueue(ctx, "", tx, txm.SetComputeUnitLimit(500), // reduce from default 200K limit - should only take 450 compute units // no fee bumping and no additional fee - makes validating balance accurate txm.SetComputeUnitPriceMax(0), @@ -576,13 +574,13 @@ func (c *chain) sendTx(ctx context.Context, from, to string, amount *big.Int, ba return nil } -func solanaValidateBalance(reader client.Reader, from solanago.PublicKey, amount uint64, msg string) error { - balance, err := reader.Balance(from) +func solanaValidateBalance(ctx context.Context, reader client.Reader, from solanago.PublicKey, amount uint64, msg string) error { + balance, err := reader.Balance(ctx, from) if err != nil { return err } - fee, err := reader.GetFeeForMessage(msg) + fee, err := reader.GetFeeForMessage(ctx, msg) if err != nil { return err } diff --git a/pkg/solana/chain_test.go b/pkg/solana/chain_test.go index 4097e38dd..0e52741d2 100644 --- a/pkg/solana/chain_test.go +++ b/pkg/solana/chain_test.go @@ -126,6 +126,7 @@ func TestSolanaChain_GetClient(t *testing.T) { } func TestSolanaChain_VerifiedClient(t *testing.T) { + ctx := tests.Context(t) called := false mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { out := `{ "jsonrpc": "2.0", "result": 1234, "id": 1 }` // getSlot response @@ -175,7 +176,7 @@ func TestSolanaChain_VerifiedClient(t *testing.T) { // retrieve cached client and retrieve slot height c, err := testChain.verifiedClient(node) require.NoError(t, err) - slot, err := c.SlotHeight() + slot, err := c.SlotHeight(ctx) assert.NoError(t, err) assert.Equal(t, uint64(1234), slot) diff --git a/pkg/solana/client/client.go b/pkg/solana/client/client.go index 785e7e508..d649cb2c8 100644 --- a/pkg/solana/client/client.go +++ b/pkg/solana/client/client.go @@ -33,12 +33,12 @@ type ReaderWriter interface { type Reader interface { AccountReader - Balance(addr solana.PublicKey) (uint64, error) - SlotHeight() (uint64, error) - LatestBlockhash() (*rpc.GetLatestBlockhashResult, error) + Balance(ctx context.Context, addr solana.PublicKey) (uint64, error) + SlotHeight(ctx context.Context) (uint64, error) + LatestBlockhash(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) ChainID(ctx context.Context) (mn.StringID, error) - GetFeeForMessage(msg string) (uint64, error) - GetLatestBlock() (*rpc.GetBlockResult, error) + GetFeeForMessage(ctx context.Context, msg string) (uint64, error) + GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error) } // AccountReader is an interface that allows users to pass either the solana rpc client or the relay client @@ -160,11 +160,11 @@ func (c *Client) latency(name string) func() { } } -func (c *Client) Balance(addr solana.PublicKey) (uint64, error) { +func (c *Client) Balance(ctx context.Context, addr solana.PublicKey) (uint64, error) { done := c.latency("balance") defer done() - ctx, cancel := context.WithTimeout(context.Background(), c.contextDuration) + ctx, cancel := context.WithTimeout(ctx, c.contextDuration) defer cancel() v, err, _ := c.requestGroup.Do(fmt.Sprintf("GetBalance(%s)", addr.String()), func() (interface{}, error) { @@ -177,15 +177,15 @@ func (c *Client) Balance(addr solana.PublicKey) (uint64, error) { return res.Value, err } -func (c *Client) SlotHeight() (uint64, error) { - return c.SlotHeightWithCommitment(rpc.CommitmentProcessed) // get the latest slot height +func (c *Client) SlotHeight(ctx context.Context) (uint64, error) { + return c.SlotHeightWithCommitment(ctx, rpc.CommitmentProcessed) // get the latest slot height } -func (c *Client) SlotHeightWithCommitment(commitment rpc.CommitmentType) (uint64, error) { +func (c *Client) SlotHeightWithCommitment(ctx context.Context, commitment rpc.CommitmentType) (uint64, error) { done := c.latency("slot_height") defer done() - ctx, cancel := context.WithTimeout(context.Background(), c.contextDuration) + ctx, cancel := context.WithTimeout(ctx, c.contextDuration) defer cancel() v, err, _ := c.requestGroup.Do("GetSlotHeight", func() (interface{}, error) { return c.rpc.GetSlot(ctx, commitment) @@ -203,11 +203,11 @@ func (c *Client) GetAccountInfoWithOpts(ctx context.Context, addr solana.PublicK return c.rpc.GetAccountInfoWithOpts(ctx, addr, opts) } -func (c *Client) LatestBlockhash() (*rpc.GetLatestBlockhashResult, error) { +func (c *Client) LatestBlockhash(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) { done := c.latency("latest_blockhash") defer done() - ctx, cancel := context.WithTimeout(context.Background(), c.contextDuration) + ctx, cancel := context.WithTimeout(ctx, c.contextDuration) defer cancel() v, err, _ := c.requestGroup.Do("GetLatestBlockhash", func() (interface{}, error) { @@ -245,13 +245,13 @@ func (c *Client) ChainID(ctx context.Context) (mn.StringID, error) { return mn.StringID(network), nil } -func (c *Client) GetFeeForMessage(msg string) (uint64, error) { +func (c *Client) GetFeeForMessage(ctx context.Context, msg string) (uint64, error) { done := c.latency("fee_for_message") defer done() // msg is base58 encoded data - ctx, cancel := context.WithTimeout(context.Background(), c.contextDuration) + ctx, cancel := context.WithTimeout(ctx, c.contextDuration) defer cancel() res, err := c.rpc.GetFeeForMessage(ctx, msg, c.commitment) if err != nil { @@ -328,9 +328,9 @@ func (c *Client) SendTx(ctx context.Context, tx *solana.Transaction) (solana.Sig return c.rpc.SendTransactionWithOpts(ctx, tx, opts) } -func (c *Client) GetLatestBlock() (*rpc.GetBlockResult, error) { +func (c *Client) GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error) { // get latest confirmed slot - slot, err := c.SlotHeightWithCommitment(c.commitment) + slot, err := c.SlotHeightWithCommitment(ctx, c.commitment) if err != nil { return nil, fmt.Errorf("GetLatestBlock.SlotHeight: %w", err) } @@ -338,7 +338,7 @@ func (c *Client) GetLatestBlock() (*rpc.GetBlockResult, error) { // get block based on slot done := c.latency("latest_block") defer done() - ctx, cancel := context.WithTimeout(context.Background(), c.txTimeout) + ctx, cancel := context.WithTimeout(ctx, c.txTimeout) defer cancel() v, err, _ := c.requestGroup.Do("GetBlockWithOpts", func() (interface{}, error) { version := uint64(0) // pull all tx types (legacy + v0) diff --git a/pkg/solana/client/client_test.go b/pkg/solana/client/client_test.go index 6a4feb61f..54d130206 100644 --- a/pkg/solana/client/client_test.go +++ b/pkg/solana/client/client_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" @@ -26,6 +27,7 @@ import ( ) func TestClient_Reader_Integration(t *testing.T) { + ctx := tests.Context(t) url := SetupLocalSolNode(t) privKey, err := solana.NewRandomPrivateKey() require.NoError(t, err) @@ -40,21 +42,21 @@ func TestClient_Reader_Integration(t *testing.T) { require.NoError(t, err) // check balance - bal, err := c.Balance(pubKey) + bal, err := c.Balance(ctx, pubKey) assert.NoError(t, err) assert.Equal(t, uint64(100_000_000_000), bal) // once funds get sent to the system program it should be unrecoverable (so this number should remain > 0) // check SlotHeight - slot0, err := c.SlotHeight() + slot0, err := c.SlotHeight(ctx) assert.NoError(t, err) assert.Greater(t, slot0, uint64(0)) time.Sleep(time.Second) - slot1, err := c.SlotHeight() + slot1, err := c.SlotHeight(ctx) assert.NoError(t, err) assert.Greater(t, slot1, slot0) // fetch recent blockhash - hash, err := c.LatestBlockhash() + hash, err := c.LatestBlockhash(ctx) assert.NoError(t, err) assert.NotEqual(t, hash.Value.Blockhash, solana.Hash{}) // not an empty hash @@ -72,7 +74,7 @@ func TestClient_Reader_Integration(t *testing.T) { ) assert.NoError(t, err) - fee, err := c.GetFeeForMessage(tx.Message.ToBase64()) + fee, err := c.GetFeeForMessage(ctx, tx.Message.ToBase64()) assert.NoError(t, err) assert.Equal(t, uint64(5000), fee) @@ -82,13 +84,13 @@ func TestClient_Reader_Integration(t *testing.T) { assert.Equal(t, mn.StringID("localnet"), network) // get account info (also tested inside contract_test) - res, err := c.GetAccountInfoWithOpts(context.TODO(), solana.PublicKey{}, &rpc.GetAccountInfoOpts{Commitment: rpc.CommitmentFinalized}) + res, err := c.GetAccountInfoWithOpts(ctx, solana.PublicKey{}, &rpc.GetAccountInfoOpts{Commitment: rpc.CommitmentFinalized}) assert.NoError(t, err) assert.Equal(t, uint64(1), res.Value.Lamports) assert.Equal(t, "NativeLoader1111111111111111111111111111111", res.Value.Owner.String()) // get block + check for nonzero values - block, err := c.GetLatestBlock() + block, err := c.GetLatestBlock(ctx) require.NoError(t, err) assert.NotEqual(t, solana.Hash{}, block.Blockhash) assert.NotEqual(t, uint64(0), block.ParentSlot) @@ -96,6 +98,7 @@ func TestClient_Reader_Integration(t *testing.T) { } func TestClient_Reader_ChainID(t *testing.T) { + ctx := tests.Context(t) genesisHashes := []string{ DevnetGenesisHash, // devnet TestnetGenesisHash, // testnet @@ -121,7 +124,7 @@ func TestClient_Reader_ChainID(t *testing.T) { // get chain ID based on gensis hash for _, n := range networks { - network, err := c.ChainID(context.Background()) + network, err := c.ChainID(ctx) assert.NoError(t, err) assert.Equal(t, mn.StringID(n), network) } @@ -138,13 +141,13 @@ func TestClient_Writer_Integration(t *testing.T) { lggr := logger.Test(t) cfg := config.NewDefault() - ctx := context.Background() + ctx := tests.Context(t) c, err := NewClient(url, cfg, requestTimeout, lggr) require.NoError(t, err) // create + sign transaction createTx := func(to solana.PublicKey) *solana.Transaction { - hash, hashErr := c.LatestBlockhash() + hash, hashErr := c.LatestBlockhash(ctx) assert.NoError(t, hashErr) tx, txErr := solana.NewTransaction( @@ -212,6 +215,7 @@ func TestClient_Writer_Integration(t *testing.T) { } func TestClient_SendTxDuplicates_Integration(t *testing.T) { + ctx := tests.Context(t) // set up environment url := SetupLocalSolNode(t) privKey, err := solana.NewRandomPrivateKey() @@ -227,10 +231,10 @@ func TestClient_SendTxDuplicates_Integration(t *testing.T) { require.NoError(t, err) // fetch recent blockhash - hash, err := c.LatestBlockhash() + hash, err := c.LatestBlockhash(ctx) assert.NoError(t, err) - initBal, err := c.Balance(pubKey) + initBal, err := c.Balance(ctx, pubKey) assert.NoError(t, err) // create + sign tx @@ -261,7 +265,6 @@ func TestClient_SendTxDuplicates_Integration(t *testing.T) { n := 5 sigs := make([]solana.Signature, n) var wg sync.WaitGroup - ctx := context.Background() wg.Add(5) for i := 0; i < n; i++ { go func(i int) { @@ -292,7 +295,7 @@ func TestClient_SendTxDuplicates_Integration(t *testing.T) { // expect one sender has only sent one tx // original balance - current bal = 5000 lamports (tx fee) - endBal, err := c.Balance(pubKey) + endBal, err := c.Balance(ctx, pubKey) assert.NoError(t, err) assert.Equal(t, uint64(5_000), initBal-endBal) } diff --git a/pkg/solana/client/mocks/ReaderWriter.go b/pkg/solana/client/mocks/ReaderWriter.go index b6cd6808a..fd750fdb5 100644 --- a/pkg/solana/client/mocks/ReaderWriter.go +++ b/pkg/solana/client/mocks/ReaderWriter.go @@ -17,9 +17,9 @@ type ReaderWriter struct { mock.Mock } -// Balance provides a mock function with given fields: addr -func (_m *ReaderWriter) Balance(addr solana.PublicKey) (uint64, error) { - ret := _m.Called(addr) +// Balance provides a mock function with given fields: ctx, addr +func (_m *ReaderWriter) Balance(ctx context.Context, addr solana.PublicKey) (uint64, error) { + ret := _m.Called(ctx, addr) if len(ret) == 0 { panic("no return value specified for Balance") @@ -27,17 +27,17 @@ func (_m *ReaderWriter) Balance(addr solana.PublicKey) (uint64, error) { var r0 uint64 var r1 error - if rf, ok := ret.Get(0).(func(solana.PublicKey) (uint64, error)); ok { - return rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey) (uint64, error)); ok { + return rf(ctx, addr) } - if rf, ok := ret.Get(0).(func(solana.PublicKey) uint64); ok { - r0 = rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, solana.PublicKey) uint64); ok { + r0 = rf(ctx, addr) } else { r0 = ret.Get(0).(uint64) } - if rf, ok := ret.Get(1).(func(solana.PublicKey) error); ok { - r1 = rf(addr) + if rf, ok := ret.Get(1).(func(context.Context, solana.PublicKey) error); ok { + r1 = rf(ctx, addr) } else { r1 = ret.Error(1) } @@ -103,9 +103,9 @@ func (_m *ReaderWriter) GetAccountInfoWithOpts(ctx context.Context, addr solana. return r0, r1 } -// GetFeeForMessage provides a mock function with given fields: msg -func (_m *ReaderWriter) GetFeeForMessage(msg string) (uint64, error) { - ret := _m.Called(msg) +// GetFeeForMessage provides a mock function with given fields: ctx, msg +func (_m *ReaderWriter) GetFeeForMessage(ctx context.Context, msg string) (uint64, error) { + ret := _m.Called(ctx, msg) if len(ret) == 0 { panic("no return value specified for GetFeeForMessage") @@ -113,17 +113,17 @@ func (_m *ReaderWriter) GetFeeForMessage(msg string) (uint64, error) { var r0 uint64 var r1 error - if rf, ok := ret.Get(0).(func(string) (uint64, error)); ok { - return rf(msg) + if rf, ok := ret.Get(0).(func(context.Context, string) (uint64, error)); ok { + return rf(ctx, msg) } - if rf, ok := ret.Get(0).(func(string) uint64); ok { - r0 = rf(msg) + if rf, ok := ret.Get(0).(func(context.Context, string) uint64); ok { + r0 = rf(ctx, msg) } else { r0 = ret.Get(0).(uint64) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(msg) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, msg) } else { r1 = ret.Error(1) } @@ -131,9 +131,9 @@ func (_m *ReaderWriter) GetFeeForMessage(msg string) (uint64, error) { return r0, r1 } -// GetLatestBlock provides a mock function with given fields: -func (_m *ReaderWriter) GetLatestBlock() (*rpc.GetBlockResult, error) { - ret := _m.Called() +// GetLatestBlock provides a mock function with given fields: ctx +func (_m *ReaderWriter) GetLatestBlock(ctx context.Context) (*rpc.GetBlockResult, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetLatestBlock") @@ -141,19 +141,19 @@ func (_m *ReaderWriter) GetLatestBlock() (*rpc.GetBlockResult, error) { var r0 *rpc.GetBlockResult var r1 error - if rf, ok := ret.Get(0).(func() (*rpc.GetBlockResult, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (*rpc.GetBlockResult, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() *rpc.GetBlockResult); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) *rpc.GetBlockResult); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rpc.GetBlockResult) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -161,9 +161,9 @@ func (_m *ReaderWriter) GetLatestBlock() (*rpc.GetBlockResult, error) { return r0, r1 } -// LatestBlockhash provides a mock function with given fields: -func (_m *ReaderWriter) LatestBlockhash() (*rpc.GetLatestBlockhashResult, error) { - ret := _m.Called() +// LatestBlockhash provides a mock function with given fields: ctx +func (_m *ReaderWriter) LatestBlockhash(ctx context.Context) (*rpc.GetLatestBlockhashResult, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for LatestBlockhash") @@ -171,19 +171,19 @@ func (_m *ReaderWriter) LatestBlockhash() (*rpc.GetLatestBlockhashResult, error) var r0 *rpc.GetLatestBlockhashResult var r1 error - if rf, ok := ret.Get(0).(func() (*rpc.GetLatestBlockhashResult, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (*rpc.GetLatestBlockhashResult, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() *rpc.GetLatestBlockhashResult); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) *rpc.GetLatestBlockhashResult); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rpc.GetLatestBlockhashResult) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -281,9 +281,9 @@ func (_m *ReaderWriter) SimulateTx(ctx context.Context, tx *solana.Transaction, return r0, r1 } -// SlotHeight provides a mock function with given fields: -func (_m *ReaderWriter) SlotHeight() (uint64, error) { - ret := _m.Called() +// SlotHeight provides a mock function with given fields: ctx +func (_m *ReaderWriter) SlotHeight(ctx context.Context) (uint64, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for SlotHeight") @@ -291,17 +291,17 @@ func (_m *ReaderWriter) SlotHeight() (uint64, error) { var r0 uint64 var r1 error - if rf, ok := ret.Get(0).(func() (uint64, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (uint64, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() uint64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) uint64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(uint64) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } diff --git a/pkg/solana/client/test_helpers_test.go b/pkg/solana/client/test_helpers_test.go index 1f530da2b..297296ec8 100644 --- a/pkg/solana/client/test_helpers_test.go +++ b/pkg/solana/client/test_helpers_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" ) @@ -28,19 +29,20 @@ func TestSetupLocalSolNode_SimultaneousNetworks(t *testing.T) { // check & fund address checkFunded := func(t *testing.T, url string) { + ctx := tests.Context(t) // create client c, err := NewClient(url, cfg, requestTimeout, lggr) require.NoError(t, err) // check init balance - bal, err := c.Balance(pubkey) + bal, err := c.Balance(ctx, pubkey) assert.NoError(t, err) assert.Equal(t, uint64(0), bal) FundTestAccounts(t, []solana.PublicKey{pubkey}, url) // check end balance - bal, err = c.Balance(pubkey) + bal, err = c.Balance(ctx, pubkey) assert.NoError(t, err) assert.Equal(t, uint64(100_000_000_000), bal) // once funds get sent to the system program it should be unrecoverable (so this number should remain > 0) } diff --git a/pkg/solana/config_tracker.go b/pkg/solana/config_tracker.go index a04450edd..998790b45 100644 --- a/pkg/solana/config_tracker.go +++ b/pkg/solana/config_tracker.go @@ -75,5 +75,5 @@ func (c *ConfigTracker) LatestConfig(ctx context.Context, changedInBlock uint64) // LatestBlockHeight returns the height of the most recent block in the chain. func (c *ConfigTracker) LatestBlockHeight(ctx context.Context) (blockHeight uint64, err error) { - return c.reader.SlotHeight() // this returns the latest slot height through CommitmentProcessed + return c.reader.SlotHeight(ctx) // this returns the latest slot height through CommitmentProcessed } diff --git a/pkg/solana/fees/block_history.go b/pkg/solana/fees/block_history.go index 214612e90..41c63fb2e 100644 --- a/pkg/solana/fees/block_history.go +++ b/pkg/solana/fees/block_history.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "sync" - "time" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -19,7 +18,7 @@ var _ Estimator = &blockHistoryEstimator{} type blockHistoryEstimator struct { starter services.StateMachine - chStop chan struct{} + chStop services.StopChan done sync.WaitGroup client *utils.LazyLoad[client.ReaderWriter] @@ -54,19 +53,20 @@ func (bhe *blockHistoryEstimator) Start(ctx context.Context) error { func (bhe *blockHistoryEstimator) run() { defer bhe.done.Done() + ctx, cancel := bhe.chStop.NewCtx() + defer cancel() - tick := time.After(0) + ticker := services.NewTicker(bhe.cfg.BlockHistoryPollPeriod()) + defer ticker.Stop() for { select { - case <-bhe.chStop: + case <-ctx.Done(): return - case <-tick: - if err := bhe.calculatePrice(); err != nil { + case <-ticker.C: + if err := bhe.calculatePrice(ctx); err != nil { bhe.lgr.Error(fmt.Errorf("BlockHistoryEstimator failed to fetch price: %w", err)) } } - - tick = time.After(utils.WithJitter(bhe.cfg.BlockHistoryPollPeriod())) } } @@ -98,7 +98,7 @@ func (bhe *blockHistoryEstimator) readRawPrice() uint64 { return bhe.price } -func (bhe *blockHistoryEstimator) calculatePrice() error { +func (bhe *blockHistoryEstimator) calculatePrice(ctx context.Context) error { // fetch client c, err := bhe.client.Get() if err != nil { @@ -106,7 +106,7 @@ func (bhe *blockHistoryEstimator) calculatePrice() error { } // get latest block based on configured confirmation - block, err := c.GetLatestBlock() + block, err := c.GetLatestBlock(ctx) if err != nil { return fmt.Errorf("failed to get block in blockHistoryEstimator.getFee: %w", err) } diff --git a/pkg/solana/fees/block_history_test.go b/pkg/solana/fees/block_history_test.go index a11eff81c..c22ca3aba 100644 --- a/pkg/solana/fees/block_history_test.go +++ b/pkg/solana/fees/block_history_test.go @@ -9,6 +9,7 @@ import ( "github.com/gagliardetto/solana-go/rpc" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" @@ -47,7 +48,7 @@ func TestBlockHistoryEstimator(t *testing.T) { estimator, err := NewBlockHistoryEstimator(rwLoader, cfg, lgr) require.NoError(t, err) - rw.On("GetLatestBlock").Return(blockRes, nil).Once() + rw.On("GetLatestBlock", mock.Anything).Return(blockRes, nil).Once() require.NoError(t, estimator.Start(ctx)) tests.AssertLogEventually(t, logs, "BlockHistoryEstimator: updated") assert.Equal(t, uint64(55000), estimator.readRawPrice()) @@ -61,22 +62,22 @@ func TestBlockHistoryEstimator(t *testing.T) { assert.Equal(t, estimator.readRawPrice(), estimator.BaseComputeUnitPrice()) // failed to get latest block - rw.On("GetLatestBlock").Return(nil, fmt.Errorf("fail rpc call")).Once() + rw.On("GetLatestBlock", mock.Anything).Return(nil, fmt.Errorf("fail rpc call")).Once() tests.AssertLogEventually(t, logs, "failed to get block") assert.Equal(t, validPrice, estimator.BaseComputeUnitPrice(), "price should not change when getPrice fails") // failed to parse block - rw.On("GetLatestBlock").Return(nil, nil).Once() + rw.On("GetLatestBlock", mock.Anything).Return(nil, nil).Once() tests.AssertLogEventually(t, logs, "failed to parse block") assert.Equal(t, validPrice, estimator.BaseComputeUnitPrice(), "price should not change when getPrice fails") // failed to calculate median - rw.On("GetLatestBlock").Return(&rpc.GetBlockResult{}, nil).Once() + rw.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Once() tests.AssertLogEventually(t, logs, "failed to find median") assert.Equal(t, validPrice, estimator.BaseComputeUnitPrice(), "price should not change when getPrice fails") // back to happy path - rw.On("GetLatestBlock").Return(blockRes, nil).Once() + rw.On("GetLatestBlock", mock.Anything).Return(blockRes, nil).Once() tests.AssertEventually(t, func() bool { return logs.FilterMessageSnippet("BlockHistoryEstimator: updated").Len() == 2 }) diff --git a/pkg/solana/monitor/balance.go b/pkg/solana/monitor/balance.go index 5c7c88f5d..a1ab59d69 100644 --- a/pkg/solana/monitor/balance.go +++ b/pkg/solana/monitor/balance.go @@ -22,7 +22,7 @@ type Keystore interface { } type BalanceClient interface { - Balance(addr solana.PublicKey) (uint64, error) + Balance(ctx context.Context, addr solana.PublicKey) (uint64, error) } // NewBalanceMonitor returns a balance monitoring services.Service which reports the SOL balance of all ks keys to prometheus. @@ -112,6 +112,9 @@ func (b *balanceMonitor) getReader() (BalanceClient, error) { } func (b *balanceMonitor) updateBalances(ctx context.Context) { + ctx, cancel := b.stop.Ctx(ctx) + defer cancel() + keys, err := b.ks.Accounts(ctx) if err != nil { b.lggr.Errorw("Failed to get keys", "err", err) @@ -129,7 +132,7 @@ func (b *balanceMonitor) updateBalances(ctx context.Context) { for _, k := range keys { // Check for shutdown signal, since Balance blocks and may be slow. select { - case <-b.stop: + case <-ctx.Done(): return default: } @@ -138,7 +141,7 @@ func (b *balanceMonitor) updateBalances(ctx context.Context) { b.lggr.Errorw("Failed parse public key", "account", k, "err", err) continue } - lamports, err := reader.Balance(pubKey) + lamports, err := reader.Balance(ctx, pubKey) if err != nil { b.lggr.Errorw("Failed to get balance", "account", k, "err", err) continue diff --git a/pkg/solana/monitor/balance_test.go b/pkg/solana/monitor/balance_test.go index ff98d0508..9321d6a52 100644 --- a/pkg/solana/monitor/balance_test.go +++ b/pkg/solana/monitor/balance_test.go @@ -9,9 +9,11 @@ import ( "github.com/gagliardetto/solana-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" @@ -36,13 +38,12 @@ func TestBalanceMonitor(t *testing.T) { "1.000000000", } - client := new(mocks.ReaderWriter) - client.Test(t) + client := mocks.NewReaderWriter(t) type update struct{ acc, bal string } var exp []update for i := range bals { acc := ks[i] - client.On("Balance", acc).Return(bals[i], nil) + client.On("Balance", mock.Anything, acc).Return(bals[i], nil) exp = append(exp, update{acc.String(), expBals[i]}) } cfg := &config{balancePollPeriod: time.Second} @@ -63,11 +64,7 @@ func TestBalanceMonitor(t *testing.T) { } b.reader = client - require.NoError(t, b.Start(tests.Context(t))) - t.Cleanup(func() { - assert.NoError(t, b.Close()) - client.AssertExpectations(t) - }) + servicetest.Run(t, b) select { case <-time.After(tests.WaitTimeout(t)): t.Fatal("timed out waiting for balance monitor") diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index d53e7de47..6edd11b4f 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -24,7 +24,7 @@ import ( var _ TxManager = (*txm.Txm)(nil) type TxManager interface { - Enqueue(accountID string, msg *solana.Transaction, txCfgs ...txm.SetTxConfig) error + Enqueue(ctx context.Context, accountID string, msg *solana.Transaction, txCfgs ...txm.SetTxConfig) error } var _ relaytypes.Relayer = &Relayer{} //nolint:staticcheck diff --git a/pkg/solana/transmitter.go b/pkg/solana/transmitter.go index ca5a55dc9..4a3731921 100644 --- a/pkg/solana/transmitter.go +++ b/pkg/solana/transmitter.go @@ -32,7 +32,7 @@ func (c *Transmitter) Transmit( report types.Report, sigs []types.AttributedOnchainSignature, ) error { - blockhash, err := c.reader.LatestBlockhash() + blockhash, err := c.reader.LatestBlockhash(ctx) if err != nil { return fmt.Errorf("error on Transmit.GetRecentBlockhash: %w", err) } @@ -84,7 +84,7 @@ func (c *Transmitter) Transmit( // pass transmit payload to tx manager queue c.lggr.Debugf("Queuing transmit tx: state (%s) + transmissions (%s)", c.stateID.String(), c.transmissionsID.String()) - if err = c.txManager.Enqueue(c.stateID.String(), tx); err != nil { + if err = c.txManager.Enqueue(ctx, c.stateID.String(), tx); err != nil { return fmt.Errorf("error on Transmit.txManager.Enqueue: %w", err) } return nil diff --git a/pkg/solana/transmitter_test.go b/pkg/solana/transmitter_test.go index dd9787b93..66dd8658c 100644 --- a/pkg/solana/transmitter_test.go +++ b/pkg/solana/transmitter_test.go @@ -1,16 +1,19 @@ package solana import ( + "context" "testing" "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" @@ -23,7 +26,7 @@ type verifyTxSize struct { s *solana.PrivateKey } -func (txm verifyTxSize) Enqueue(_ string, tx *solana.Transaction, _ ...txm.SetTxConfig) error { +func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, _ ...txm.SetTxConfig) error { // additional components that transaction manager adds to the transaction require.NoError(txm.t, fees.SetComputeUnitPrice(tx, 0)) require.NoError(txm.t, fees.SetComputeUnitLimit(tx, 0)) @@ -55,7 +58,7 @@ func TestTransmitter_TxSize(t *testing.T) { } rw := clientmocks.NewReaderWriter(t) - rw.On("LatestBlockhash").Return(&rpc.GetLatestBlockhashResult{ + rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{ Value: &rpc.LatestBlockhashResult{}, }, nil) diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 3eceb8df3..b7e7d9c47 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -111,6 +111,9 @@ func (txm *Txm) Start(ctx context.Context) error { txm.done.Add(3) // waitgroup: tx retry, confirmer, simulator go txm.run() + go txm.confirm() + go txm.simulate() + return nil }) } @@ -120,10 +123,6 @@ func (txm *Txm) run() { ctx, cancel := txm.chStop.NewCtx() defer cancel() - // start confirmer + simulator - go txm.confirm(ctx) - go txm.simulate(ctx) - for { select { case msg := <-txm.chSend: @@ -152,7 +151,7 @@ func (txm *Txm) run() { } } -func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transaction, txcfg TxConfig) (solanaGo.Transaction, uuid.UUID, solanaGo.Signature, error) { +func (txm *Txm) sendWithRetry(ctx context.Context, baseTx solanaGo.Transaction, txcfg TxConfig) (solanaGo.Transaction, uuid.UUID, solanaGo.Signature, error) { // fetch client client, clientErr := txm.client.Get() if clientErr != nil { @@ -184,7 +183,7 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti } } - buildTx := func(base solanaGo.Transaction, retryCount int) (solanaGo.Transaction, error) { + buildTx := func(ctx context.Context, base solanaGo.Transaction, retryCount int) (solanaGo.Transaction, error) { newTx := base // make copy // set fee @@ -198,7 +197,7 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti if marshalErr != nil { return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.MarshalBinary: %w", marshalErr) } - sigBytes, signErr := txm.ks.Sign(context.TODO(), key, txMsg) + sigBytes, signErr := txm.ks.Sign(ctx, key, txMsg) if signErr != nil { return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.Sign: %w", signErr) } @@ -209,13 +208,13 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti return newTx, nil } - initTx, initBuildErr := buildTx(baseTx, 0) + initTx, initBuildErr := buildTx(ctx, baseTx, 0) if initBuildErr != nil { return solanaGo.Transaction{}, uuid.Nil, solanaGo.Signature{}, initBuildErr } // create timeout context - ctx, cancel := context.WithTimeout(chanCtx, txcfg.Timeout) + ctx, cancel := context.WithTimeout(ctx, txcfg.Timeout) // send initial tx (do not retry and exit early if fails) sig, initSendErr := client.SendTx(ctx, &initTx) @@ -241,10 +240,12 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti txm.lggr.Debugw("tx initial broadcast", "id", id, "signature", sig) + txm.done.Add(1) // retry with exponential backoff // until context cancelled by timeout or called externally // pass in copy of baseTx (used to build new tx with bumped fee) and broadcasted tx == initTx (used to retry tx without bumping) - go func(baseTx, currentTx solanaGo.Transaction) { + go func(ctx context.Context, baseTx, currentTx solanaGo.Transaction) { + defer txm.done.Done() deltaT := 1 // ms tick := time.After(0) bumpCount := 0 @@ -256,7 +257,7 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti case <-ctx.Done(): // stop sending tx after retry tx ctx times out (does not stop confirmation polling for tx) wg.Wait() - txm.lggr.Debugw("stopped tx retry", "id", id, "signatures", sigs.List()) + txm.lggr.Debugw("stopped tx retry", "id", id, "signatures", sigs.List(), "err", context.Cause(ctx)) return case <-tick: var shouldBump bool @@ -270,7 +271,7 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti // if fee should be bumped, build new tx and replace currentTx if shouldBump { var retryBuildErr error - currentTx, retryBuildErr = buildTx(baseTx, bumpCount) + currentTx, retryBuildErr = buildTx(ctx, baseTx, bumpCount) if retryBuildErr != nil { txm.lggr.Errorw("failed to build bumped retry tx", "error", retryBuildErr, "id", id) return // exit func if cannot build tx for retrying @@ -337,7 +338,7 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti } tick = time.After(time.Duration(deltaT) * time.Millisecond) } - }(baseTx, initTx) + }(ctx, baseTx, initTx) // return signed tx, id, signature for use in simulation return initTx, id, sig, nil @@ -345,8 +346,10 @@ func (txm *Txm) sendWithRetry(chanCtx context.Context, baseTx solanaGo.Transacti // goroutine that polls to confirm implementation // cancels the exponential retry once confirmed -func (txm *Txm) confirm(ctx context.Context) { +func (txm *Txm) confirm() { defer txm.done.Done() + ctx, cancel := txm.chStop.NewCtx() + defer cancel() tick := time.After(0) for { @@ -441,7 +444,6 @@ func (txm *Txm) confirm(ctx context.Context) { // waitgroup for processing var wg sync.WaitGroup - wg.Add(len(sigsBatch)) // loop through batch for i := 0; i < len(sigsBatch); i++ { @@ -449,10 +451,10 @@ func (txm *Txm) confirm(ctx context.Context) { statuses, err := client.SignatureStatuses(ctx, sigsBatch[i]) if err != nil { txm.lggr.Errorw("failed to get signature statuses in soltxm.confirm", "error", err) - wg.Done() // don't block if exit early - break // exit for loop + break // exit for loop } + wg.Add(1) // nonblocking: process batches as soon as they come in go func(index int) { defer wg.Done() @@ -468,8 +470,10 @@ func (txm *Txm) confirm(ctx context.Context) { // goroutine that simulates tx (use a bounded number of goroutines to pick from queue?) // simulate can cancel the send retry function early in the tx management process // additionally, it can provide reasons for why a tx failed in the logs -func (txm *Txm) simulate(ctx context.Context) { +func (txm *Txm) simulate() { defer txm.done.Done() + ctx, cancel := txm.chStop.NewCtx() + defer cancel() for { select { @@ -526,7 +530,7 @@ func (txm *Txm) simulate(ctx context.Context) { } // Enqueue enqueue a msg destined for the solana chain. -func (txm *Txm) Enqueue(accountID string, tx *solanaGo.Transaction, txCfgs ...SetTxConfig) error { +func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txCfgs ...SetTxConfig) error { if err := txm.Ready(); err != nil { return fmt.Errorf("error in soltxm.Enqueue: %w", err) } @@ -543,7 +547,7 @@ func (txm *Txm) Enqueue(accountID string, tx *solanaGo.Transaction, txCfgs ...Se // validate expected key exists by trying to sign with it // fee payer account is index 0 account // https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 - _, err := txm.ks.Sign(context.TODO(), tx.Message.AccountKeys[0].String(), nil) + _, err := txm.ks.Sign(ctx, tx.Message.AccountKeys[0].String(), nil) if err != nil { return fmt.Errorf("error in soltxm.Enqueue.GetKey: %w", err) } diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 17b73d0e6..c420dd811 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -104,7 +104,7 @@ func TestTxm(t *testing.T) { cfg := config.NewDefault() cfg.Chain.FeeEstimatorMode = &estimator mc := mocks.NewReaderWriter(t) - mc.On("GetLatestBlock").Return(&rpc.GetBlockResult{}, nil).Maybe() + mc.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe() // mock solana keystore mkey := keyMocks.NewSimpleKeystore(t) @@ -196,7 +196,7 @@ func TestTxm(t *testing.T) { } // send tx - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // no transactions stored inflight txs list @@ -227,7 +227,7 @@ func TestTxm(t *testing.T) { }).Return(solana.Signature{}, errors.New("FAIL")).Once() // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed // no transactions stored inflight txs list @@ -255,7 +255,7 @@ func TestTxm(t *testing.T) { // signature status is nil (handled automatically) // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // txs cleared quickly @@ -287,7 +287,7 @@ func TestTxm(t *testing.T) { // all signature statuses are nil, handled automatically // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // txs cleared after timeout @@ -323,7 +323,7 @@ func TestTxm(t *testing.T) { // all signature statuses are nil, handled automatically // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // txs cleared after timeout @@ -366,7 +366,7 @@ func TestTxm(t *testing.T) { } // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // txs cleared after timeout @@ -403,7 +403,7 @@ func TestTxm(t *testing.T) { } // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // txs cleared after timeout @@ -444,7 +444,7 @@ func TestTxm(t *testing.T) { } // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // inflight txs cleared after timeout @@ -492,7 +492,7 @@ func TestTxm(t *testing.T) { } // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // inflight txs cleared after timeout @@ -526,7 +526,7 @@ func TestTxm(t *testing.T) { } // tx should be able to queue - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // wait to be picked up and processed waitFor(empty) // inflight txs cleared after timeout @@ -568,7 +568,7 @@ func TestTxm(t *testing.T) { } // send tx - assert.NoError(t, txm.Enqueue(t.Name(), tx)) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx)) wg.Wait() // no transactions stored inflight txs list @@ -617,7 +617,7 @@ func TestTxm(t *testing.T) { } // send tx - with disabled fee bumping - assert.NoError(t, txm.Enqueue(t.Name(), tx, SetFeeBumpPeriod(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, SetFeeBumpPeriod(0))) wg.Wait() // no transactions stored inflight txs list @@ -657,7 +657,7 @@ func TestTxm(t *testing.T) { } // send tx - with disabled fee bumping and disabled compute unit limit - assert.NoError(t, txm.Enqueue(t.Name(), tx, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) wg.Wait() // no transactions stored inflight txs list @@ -721,7 +721,7 @@ func TestTxm_Enqueue(t *testing.T) { return mc, nil }, cfg, mkey, lggr) - require.ErrorContains(t, txm.Enqueue("txmUnstarted", &solana.Transaction{}), "not started") + require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}), "not started") require.NoError(t, txm.Start(ctx)) t.Cleanup(func() { require.NoError(t, txm.Close()) }) @@ -739,10 +739,10 @@ func TestTxm_Enqueue(t *testing.T) { for _, run := range txs { t.Run(run.name, func(t *testing.T) { if !run.fail { - assert.NoError(t, txm.Enqueue(run.name, run.tx)) + assert.NoError(t, txm.Enqueue(ctx, run.name, run.tx)) return } - assert.Error(t, txm.Enqueue(run.name, run.tx)) + assert.Error(t, txm.Enqueue(ctx, run.name, run.tx)) }) } } diff --git a/pkg/solana/txm/txm_test.go b/pkg/solana/txm/txm_test.go index 851aebf89..ff7831b02 100644 --- a/pkg/solana/txm/txm_test.go +++ b/pkg/solana/txm/txm_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" @@ -75,19 +76,18 @@ func TestTxm_Integration(t *testing.T) { txm := txm.NewTxm("localnet", getClient, cfg, mkey, lggr) // track initial balance - initBal, err := client.Balance(pubKey) + initBal, err := client.Balance(ctx, pubKey) assert.NoError(t, err) assert.NotEqual(t, uint64(0), initBal) // should be funded - // start - require.NoError(t, txm.Start(ctx)) + servicetest.Run(t, txm) // already started assert.Error(t, txm.Start(ctx)) createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) *solana.Transaction { // create transfer tx - hash, err := client.LatestBlockhash() + hash, err := client.LatestBlockhash(ctx) assert.NoError(t, err) tx, err := solana.NewTransaction( []solana.Instruction{ @@ -105,47 +105,37 @@ func TestTxm_Integration(t *testing.T) { } // enqueue txs (must pass to move on to load test) - require.NoError(t, txm.Enqueue("test_success_0", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) - require.Error(t, txm.Enqueue("test_invalidSigner", createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) // cannot sign tx before enqueuing - require.NoError(t, txm.Enqueue("test_invalidReceiver", createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL))) + require.NoError(t, txm.Enqueue(ctx, "test_success_0", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) + require.Error(t, txm.Enqueue(ctx, "test_invalidSigner", createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) // cannot sign tx before enqueuing + require.NoError(t, txm.Enqueue(ctx, "test_invalidReceiver", createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL))) time.Sleep(500 * time.Millisecond) // pause 0.5s for new blockhash - require.NoError(t, txm.Enqueue("test_success_1", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) - require.NoError(t, txm.Enqueue("test_txFail", createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL))) + require.NoError(t, txm.Enqueue(ctx, "test_success_1", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL))) + require.NoError(t, txm.Enqueue(ctx, "test_txFail", createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL))) // load test: try to overload txs, confirm, or simulation for i := 0; i < 1000; i++ { - assert.NoError(t, txm.Enqueue(fmt.Sprintf("load_%d", i), createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)))) + assert.NoError(t, txm.Enqueue(ctx, fmt.Sprintf("load_%d", i), createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)))) time.Sleep(10 * time.Millisecond) // ~100 txs per second (note: have run 5ms delays for ~200tx/s succesfully) } // check to make sure all txs are closed out from inflight list (longest should last MaxConfirmTimeout) - ctx, cancel := context.WithCancel(ctx) - t.Cleanup(cancel) - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - loop: - for { - select { - case <-ctx.Done(): - assert.Equal(t, 0, txm.InflightTxs()) - break loop - case <-ticker.C: - if txm.InflightTxs() == 0 { - cancel() // exit for loop - } - } - } - assert.NoError(t, txm.Close()) + require.Eventually(t, func() bool { + txs := txm.InflightTxs() + t.Logf("Inflight txs: %d", txs) + return txs == 0 + }, tests.WaitTimeout(t), time.Second) // check balance changes - senderBal, err := client.Balance(pubKey) - assert.NoError(t, err) - assert.Greater(t, initBal, senderBal) - assert.Greater(t, initBal-senderBal, 2*solana.LAMPORTS_PER_SOL) // balance change = sent + fees + senderBal, err := client.Balance(ctx, pubKey) + if assert.NoError(t, err) { + assert.Greater(t, initBal, senderBal) + assert.Greater(t, initBal-senderBal, 2*solana.LAMPORTS_PER_SOL) // balance change = sent + fees + } - receiverBal, err := client.Balance(pubKeyReceiver) - assert.NoError(t, err) - assert.Equal(t, 2*solana.LAMPORTS_PER_SOL, receiverBal) + receiverBal, err := client.Balance(ctx, pubKeyReceiver) + if assert.NoError(t, err) { + assert.Equal(t, 2*solana.LAMPORTS_PER_SOL, receiverBal) + } }) } }