From 665f5196199cd298becf4473247207c34261b4d6 Mon Sep 17 00:00:00 2001 From: Dylan Tinianov Date: Thu, 17 Oct 2024 11:09:48 -0400 Subject: [PATCH] Implement RPC Client Methods (#845) * MultiNode integration setup * Update MultiNode files * Add MultiNode flag * Remove internal dependency * Fix build * Fix import cycle * tidy * Update client_test.go * lint * Fix duplicate metrics * Add chain multinode flag * Extend client * Implement rpc client methods * Add defaults * Add latest block methods * Address comments * lint * Fix lint overflow issues * Update transaction_sender.go * Fix lint * Validate node config * Update toml.go * Add SendOnly nodes * Use pointers on config * Add test outlines * Use test context * Use configured selection mode * Set defaults * lint * Add nil check * Add client test * Add subscription test * tidy * Fix imports * Update chain_test.go * Update multinode.go * Add comments * Update multinode.go * Wrap multinode config * Fix imports * Update .golangci.yml * Use MultiNode * Add multinode to txm * Use MultiNode * Update chain.go * Update balance_test.go * Add retries * Fix head * Update client.go * lint * lint * Address comments * Remove total difficulty * Register polling subs * Extract MultiNodeClient * Remove caching changes * Undo cache changes * Fix tests * Fix variables * Fix imports * lint * Update txm_internal_test.go * lint * Update multinode_client.go * Add tests * Add dial comment * Update pkg/solana/client/multinode_client.go Co-authored-by: Dmytro Haidashenko <34754799+dhaidashenko@users.noreply.github.com> --------- Co-authored-by: Dmytro Haidashenko <34754799+dhaidashenko@users.noreply.github.com> --- pkg/solana/chain.go | 26 +- pkg/solana/chain_test.go | 59 ++++ pkg/solana/client/client.go | 78 +----- pkg/solana/client/multinode_client.go | 307 +++++++++++++++++++++ pkg/solana/client/multinode_client_test.go | 148 ++++++++++ 5 files changed, 530 insertions(+), 88 deletions(-) create mode 100644 pkg/solana/client/multinode_client.go create mode 100644 pkg/solana/client/multinode_client_test.go diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index a4076ed1d..56f37bc07 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -89,8 +89,8 @@ type chain struct { lggr logger.Logger // if multiNode is enabled, the clientCache will not be used - multiNode *mn.MultiNode[mn.StringID, *client.Client] - txSender *mn.TransactionSender[*solanago.Transaction, mn.StringID, *client.Client] + multiNode *mn.MultiNode[mn.StringID, *client.MultiNodeClient] + txSender *mn.TransactionSender[*solanago.Transaction, mn.StringID, *client.MultiNodeClient] // tracking node chain id for verification clientCache map[string]*verifiedCachedClient // map URL -> {client, chainId} [mainnet/testnet/devnet/localnet] @@ -235,28 +235,29 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L mnCfg := &cfg.MultiNode - var nodes []mn.Node[mn.StringID, *client.Client] - var sendOnlyNodes []mn.SendOnlyNode[mn.StringID, *client.Client] + var nodes []mn.Node[mn.StringID, *client.MultiNodeClient] + var sendOnlyNodes []mn.SendOnlyNode[mn.StringID, *client.MultiNodeClient] for i, nodeInfo := range cfg.ListNodes() { - rpcClient, err := client.NewClient(nodeInfo.URL.String(), cfg, DefaultRequestTimeout, logger.Named(lggr, "Client."+*nodeInfo.Name)) + rpcClient, err := client.NewMultiNodeClient(nodeInfo.URL.String(), cfg, DefaultRequestTimeout, logger.Named(lggr, "Client."+*nodeInfo.Name)) if err != nil { lggr.Warnw("failed to create client", "name", *nodeInfo.Name, "solana-url", nodeInfo.URL.String(), "err", err.Error()) return nil, fmt.Errorf("failed to create client: %w", err) } - newNode := mn.NewNode[mn.StringID, *client.Head, *client.Client]( - mnCfg, mnCfg, lggr, *nodeInfo.URL.URL(), nil, *nodeInfo.Name, - i, mn.StringID(id), 0, rpcClient, chainFamily) - if nodeInfo.SendOnly { - sendOnlyNodes = append(sendOnlyNodes, newNode) + newSendOnly := mn.NewSendOnlyNode[mn.StringID, *client.MultiNodeClient]( + lggr, *nodeInfo.URL.URL(), *nodeInfo.Name, mn.StringID(id), rpcClient) + sendOnlyNodes = append(sendOnlyNodes, newSendOnly) } else { + newNode := mn.NewNode[mn.StringID, *client.Head, *client.MultiNodeClient]( + mnCfg, mnCfg, lggr, *nodeInfo.URL.URL(), nil, *nodeInfo.Name, + i, mn.StringID(id), 0, rpcClient, chainFamily) nodes = append(nodes, newNode) } } - multiNode := mn.NewMultiNode[mn.StringID, *client.Client]( + multiNode := mn.NewMultiNode[mn.StringID, *client.MultiNodeClient]( lggr, mnCfg.SelectionMode(), mnCfg.LeaseDuration(), @@ -273,7 +274,7 @@ func newChain(id string, cfg *config.TOMLConfig, ks loop.Keystore, lggr logger.L return 0 // TODO ClassifySendError(err, clientErrors, logger.Sugared(logger.Nop()), tx, common.Address{}, false) } - txSender := mn.NewTransactionSender[*solanago.Transaction, mn.StringID, *client.Client]( + txSender := mn.NewTransactionSender[*solanago.Transaction, mn.StringID, *client.MultiNodeClient]( lggr, mn.StringID(id), chainFamily, @@ -395,6 +396,7 @@ func (c *chain) ChainID() string { } // getClient returns a client, randomly selecting one from available and valid nodes +// If multinode is enabled, it will return a client using the multinode selection instead. func (c *chain) getClient() (client.ReaderWriter, error) { if c.cfg.MultiNode.Enabled() { return c.multiNode.SelectRPC() diff --git a/pkg/solana/chain_test.go b/pkg/solana/chain_test.go index 0e52741d2..9f32096d6 100644 --- a/pkg/solana/chain_test.go +++ b/pkg/solana/chain_test.go @@ -59,6 +59,7 @@ func TestSolanaChain_GetClient(t *testing.T) { ChainID: ptr("devnet"), Chain: ch, } + cfg.SetDefaults() testChain := chain{ id: "devnet", cfg: cfg, @@ -125,6 +126,61 @@ func TestSolanaChain_GetClient(t *testing.T) { assert.NoError(t, err) } +func TestSolanaChain_MultiNode_GetClient(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := fmt.Sprintf(TestSolanaGenesisHashTemplate, client.MainnetGenesisHash) // mainnet genesis hash + if !strings.Contains(r.URL.Path, "/mismatch") { + // devnet gensis hash + out = fmt.Sprintf(TestSolanaGenesisHashTemplate, client.DevnetGenesisHash) + } + _, err := w.Write([]byte(out)) + require.NoError(t, err) + })) + defer mockServer.Close() + + ch := solcfg.Chain{} + ch.SetDefaults() + mn := solcfg.MultiNodeConfig{ + MultiNode: solcfg.MultiNode{ + Enabled: ptr(true), + }, + } + mn.SetDefaults() + + cfg := &solcfg.TOMLConfig{ + ChainID: ptr("devnet"), + Chain: ch, + MultiNode: mn, + } + cfg.Nodes = []*solcfg.Node{ + { + Name: ptr("devnet"), + URL: config.MustParseURL(mockServer.URL + "/1"), + }, + { + Name: ptr("devnet"), + URL: config.MustParseURL(mockServer.URL + "/2"), + }, + } + + testChain, err := newChain("devnet", cfg, nil, logger.Test(t)) + require.NoError(t, err) + + err = testChain.Start(tests.Context(t)) + require.NoError(t, err) + defer func() { + closeErr := testChain.Close() + require.NoError(t, closeErr) + }() + + selectedClient, err := testChain.getClient() + assert.NoError(t, err) + + id, err := selectedClient.ChainID(tests.Context(t)) + assert.NoError(t, err) + assert.Equal(t, "devnet", id.String()) +} + func TestSolanaChain_VerifiedClient(t *testing.T) { ctx := tests.Context(t) called := false @@ -157,6 +213,8 @@ func TestSolanaChain_VerifiedClient(t *testing.T) { ChainID: ptr("devnet"), Chain: ch, } + cfg.SetDefaults() + testChain := chain{ cfg: cfg, lggr: logger.Test(t), @@ -205,6 +263,7 @@ func TestSolanaChain_VerifiedClient_ParallelClients(t *testing.T) { Enabled: ptr(true), Chain: ch, } + cfg.SetDefaults() testChain := chain{ id: "devnet", cfg: cfg, diff --git a/pkg/solana/client/client.go b/pkg/solana/client/client.go index d649cb2c8..9b55cf595 100644 --- a/pkg/solana/client/client.go +++ b/pkg/solana/client/client.go @@ -4,17 +4,14 @@ import ( "context" "errors" "fmt" - "math/big" "time" - mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" - "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" - "golang.org/x/sync/singleflight" - "github.com/smartcontractkit/chainlink-common/pkg/logger" + "golang.org/x/sync/singleflight" + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" ) @@ -68,27 +65,6 @@ type Client struct { requestGroup *singleflight.Group } -type Head struct { - rpc.GetBlockResult -} - -func (h *Head) BlockNumber() int64 { - if !h.IsValid() { - return 0 - } - // nolint:gosec - // G115: integer overflow conversion uint64 -> int64 - return int64(*h.BlockHeight) -} - -func (h *Head) BlockDifficulty() *big.Int { - return nil -} - -func (h *Head) IsValid() bool { - return h.BlockHeight != nil -} - func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, log logger.Logger) (*Client, error) { return &Client{ url: endpoint, @@ -103,56 +79,6 @@ func NewClient(endpoint string, cfg config.Config, requestTimeout time.Duration, }, nil } -var _ mn.RPCClient[mn.StringID, *Head] = (*Client)(nil) -var _ mn.SendTxRPCClient[*solana.Transaction] = (*Client)(nil) - -// TODO: BCI-4061: Implement Client for MultiNode - -func (c *Client) Dial(ctx context.Context) error { - //TODO implement me - panic("implement me") -} - -func (c *Client) SubscribeToHeads(ctx context.Context) (<-chan *Head, mn.Subscription, error) { - //TODO implement me - panic("implement me") -} - -func (c *Client) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *Head, mn.Subscription, error) { - //TODO implement me - panic("implement me") -} - -func (c *Client) Ping(ctx context.Context) error { - //TODO implement me - panic("implement me") -} - -func (c *Client) IsSyncing(ctx context.Context) (bool, error) { - //TODO implement me - panic("implement me") -} - -func (c *Client) UnsubscribeAllExcept(subs ...mn.Subscription) { - //TODO implement me - panic("implement me") -} - -func (c *Client) Close() { - //TODO implement me - panic("implement me") -} - -func (c *Client) GetInterceptedChainInfo() (latest, highestUserObservations mn.ChainInfo) { - //TODO implement me - panic("implement me") -} - -func (c *Client) SendTransaction(ctx context.Context, tx *solana.Transaction) error { - // TODO: Implement - return nil -} - func (c *Client) latency(name string) func() { start := time.Now() return func() { diff --git a/pkg/solana/client/multinode_client.go b/pkg/solana/client/multinode_client.go new file mode 100644 index 000000000..086699cef --- /dev/null +++ b/pkg/solana/client/multinode_client.go @@ -0,0 +1,307 @@ +package client + +import ( + "context" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + + mn "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/multinode" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" +) + +type Head struct { + BlockHeight *uint64 + BlockHash *solana.Hash +} + +func (h *Head) BlockNumber() int64 { + if !h.IsValid() { + return 0 + } + // nolint:gosec + // G115: integer overflow conversion uint64 -> int64 + return int64(*h.BlockHeight) +} + +func (h *Head) BlockDifficulty() *big.Int { + // Not relevant for Solana + return nil +} + +func (h *Head) IsValid() bool { + return h != nil && h.BlockHeight != nil && h.BlockHash != nil +} + +var _ mn.RPCClient[mn.StringID, *Head] = (*MultiNodeClient)(nil) +var _ mn.SendTxRPCClient[*solana.Transaction] = (*MultiNodeClient)(nil) + +type MultiNodeClient struct { + Client + cfg *config.TOMLConfig + stateMu sync.RWMutex // protects state* fields + subsSliceMu sync.RWMutex + subs map[mn.Subscription]struct{} + + // chStopInFlight can be closed to immediately cancel all in-flight requests on + // this RpcClient. Closing and replacing should be serialized through + // stateMu since it can happen on state transitions as well as RpcClient Close. + chStopInFlight chan struct{} + + chainInfoLock sync.RWMutex + // intercepted values seen by callers of the rpcClient excluding health check calls. Need to ensure MultiNode provides repeatable read guarantee + highestUserObservations mn.ChainInfo + // most recent chain info observed during current lifecycle (reseted on DisconnectAll) + latestChainInfo mn.ChainInfo +} + +func NewMultiNodeClient(endpoint string, cfg *config.TOMLConfig, requestTimeout time.Duration, log logger.Logger) (*MultiNodeClient, error) { + client, err := NewClient(endpoint, cfg, requestTimeout, log) + if err != nil { + return nil, err + } + + return &MultiNodeClient{ + Client: *client, + cfg: cfg, + subs: make(map[mn.Subscription]struct{}), + chStopInFlight: make(chan struct{}), + }, nil +} + +// registerSub adds the sub to the rpcClient list +func (m *MultiNodeClient) registerSub(sub mn.Subscription, stopInFLightCh chan struct{}) error { + m.subsSliceMu.Lock() + defer m.subsSliceMu.Unlock() + // ensure that the `sub` belongs to current life cycle of the `rpcClient` and it should not be killed due to + // previous `DisconnectAll` call. + select { + case <-stopInFLightCh: + sub.Unsubscribe() + return fmt.Errorf("failed to register subscription - all in-flight requests were canceled") + default: + } + // TODO: BCI-3358 - delete sub when caller unsubscribes. + m.subs[sub] = struct{}{} + return nil +} + +func (m *MultiNodeClient) Dial(ctx context.Context) error { + // Not relevant for Solana as the RPCs don't need to be dialled. + return nil +} + +func (m *MultiNodeClient) SubscribeToHeads(ctx context.Context) (<-chan *Head, mn.Subscription, error) { + ctx, cancel, chStopInFlight, _ := m.acquireQueryCtx(ctx, m.cfg.TxTimeout()) + defer cancel() + + pollInterval := m.cfg.MultiNode.PollInterval() + if pollInterval == 0 { + return nil, nil, errors.New("PollInterval is 0") + } + timeout := pollInterval + poller, channel := mn.NewPoller[*Head](pollInterval, m.LatestBlock, timeout, m.log) + if err := poller.Start(ctx); err != nil { + return nil, nil, err + } + + err := m.registerSub(&poller, chStopInFlight) + if err != nil { + poller.Unsubscribe() + return nil, nil, err + } + + return channel, &poller, nil +} + +func (m *MultiNodeClient) SubscribeToFinalizedHeads(ctx context.Context) (<-chan *Head, mn.Subscription, error) { + ctx, cancel, chStopInFlight, _ := m.acquireQueryCtx(ctx, m.contextDuration) + defer cancel() + + finalizedBlockPollInterval := m.cfg.MultiNode.FinalizedBlockPollInterval() + if finalizedBlockPollInterval == 0 { + return nil, nil, errors.New("FinalizedBlockPollInterval is 0") + } + timeout := finalizedBlockPollInterval + poller, channel := mn.NewPoller[*Head](finalizedBlockPollInterval, m.LatestFinalizedBlock, timeout, m.log) + if err := poller.Start(ctx); err != nil { + return nil, nil, err + } + + err := m.registerSub(&poller, chStopInFlight) + if err != nil { + poller.Unsubscribe() + return nil, nil, err + } + + return channel, &poller, nil +} + +func (m *MultiNodeClient) LatestBlock(ctx context.Context) (*Head, error) { + // capture chStopInFlight to ensure we are not updating chainInfo with observations related to previous life cycle + ctx, cancel, chStopInFlight, rawRPC := m.acquireQueryCtx(ctx, m.contextDuration) + defer cancel() + + result, err := rawRPC.GetLatestBlockhash(ctx, rpc.CommitmentConfirmed) + if err != nil { + return nil, err + } + + head := &Head{ + BlockHeight: &result.Value.LastValidBlockHeight, + BlockHash: &result.Value.Blockhash, + } + m.onNewHead(ctx, chStopInFlight, head) + return head, nil +} + +func (m *MultiNodeClient) LatestFinalizedBlock(ctx context.Context) (*Head, error) { + ctx, cancel, chStopInFlight, rawRPC := m.acquireQueryCtx(ctx, m.contextDuration) + defer cancel() + + result, err := rawRPC.GetLatestBlockhash(ctx, rpc.CommitmentFinalized) + if err != nil { + return nil, err + } + + head := &Head{ + BlockHeight: &result.Value.LastValidBlockHeight, + BlockHash: &result.Value.Blockhash, + } + m.onNewFinalizedHead(ctx, chStopInFlight, head) + return head, nil +} + +func (m *MultiNodeClient) onNewHead(ctx context.Context, requestCh <-chan struct{}, head *Head) { + if head == nil { + return + } + + m.chainInfoLock.Lock() + defer m.chainInfoLock.Unlock() + if !mn.CtxIsHeathCheckRequest(ctx) { + m.highestUserObservations.BlockNumber = max(m.highestUserObservations.BlockNumber, head.BlockNumber()) + } + select { + case <-requestCh: // no need to update latestChainInfo, as rpcClient already started new life cycle + return + default: + m.latestChainInfo.BlockNumber = head.BlockNumber() + } +} + +func (m *MultiNodeClient) onNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head *Head) { + if head == nil { + return + } + m.chainInfoLock.Lock() + defer m.chainInfoLock.Unlock() + if !mn.CtxIsHeathCheckRequest(ctx) { + m.highestUserObservations.FinalizedBlockNumber = max(m.highestUserObservations.FinalizedBlockNumber, head.BlockNumber()) + } + select { + case <-requestCh: // no need to update latestChainInfo, as rpcClient already started new life cycle + return + default: + m.latestChainInfo.FinalizedBlockNumber = head.BlockNumber() + } +} + +// makeQueryCtx returns a context that cancels if: +// 1. Passed in ctx cancels +// 2. Passed in channel is closed +// 3. Default timeout is reached (queryTimeout) +func makeQueryCtx(ctx context.Context, ch services.StopChan, timeout time.Duration) (context.Context, context.CancelFunc) { + var chCancel, timeoutCancel context.CancelFunc + ctx, chCancel = ch.Ctx(ctx) + ctx, timeoutCancel = context.WithTimeout(ctx, timeout) + cancel := func() { + chCancel() + timeoutCancel() + } + return ctx, cancel +} + +func (m *MultiNodeClient) acquireQueryCtx(parentCtx context.Context, timeout time.Duration) (ctx context.Context, cancel context.CancelFunc, + chStopInFlight chan struct{}, raw *rpc.Client) { + // Need to wrap in mutex because state transition can cancel and replace context + m.stateMu.RLock() + chStopInFlight = m.chStopInFlight + cp := *m.rpc + raw = &cp + m.stateMu.RUnlock() + ctx, cancel = makeQueryCtx(parentCtx, chStopInFlight, timeout) + return +} + +func (m *MultiNodeClient) Ping(ctx context.Context) error { + version, err := m.rpc.GetVersion(ctx) + if err != nil { + return fmt.Errorf("ping failed: %v", err) + } + m.log.Debugf("ping client version: %s", version.SolanaCore) + return err +} + +func (m *MultiNodeClient) IsSyncing(ctx context.Context) (bool, error) { + // Not in use for Solana + return false, nil +} + +func (m *MultiNodeClient) UnsubscribeAllExcept(subs ...mn.Subscription) { + m.subsSliceMu.Lock() + defer m.subsSliceMu.Unlock() + + keepSubs := map[mn.Subscription]struct{}{} + for _, sub := range subs { + keepSubs[sub] = struct{}{} + } + + for sub := range m.subs { + if _, keep := keepSubs[sub]; !keep { + sub.Unsubscribe() + delete(m.subs, sub) + } + } +} + +// cancelInflightRequests closes and replaces the chStopInFlight +func (m *MultiNodeClient) cancelInflightRequests() { + m.stateMu.Lock() + defer m.stateMu.Unlock() + close(m.chStopInFlight) + m.chStopInFlight = make(chan struct{}) +} + +func (m *MultiNodeClient) Close() { + defer func() { + err := m.rpc.Close() + if err != nil { + m.log.Errorf("error closing rpc: %v", err) + } + }() + m.cancelInflightRequests() + m.UnsubscribeAllExcept() + m.chainInfoLock.Lock() + m.latestChainInfo = mn.ChainInfo{} + m.chainInfoLock.Unlock() +} + +func (m *MultiNodeClient) GetInterceptedChainInfo() (latest, highestUserObservations mn.ChainInfo) { + m.chainInfoLock.Lock() + defer m.chainInfoLock.Unlock() + return m.latestChainInfo, m.highestUserObservations +} + +func (m *MultiNodeClient) SendTransaction(ctx context.Context, tx *solana.Transaction) error { + // TODO: Use Transaction Sender + _, err := m.SendTx(ctx, tx) + return err +} diff --git a/pkg/solana/client/multinode_client_test.go b/pkg/solana/client/multinode_client_test.go new file mode 100644 index 000000000..04339b5be --- /dev/null +++ b/pkg/solana/client/multinode_client_test.go @@ -0,0 +1,148 @@ +package client + +import ( + "context" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" +) + +func initializeMultiNodeClient(t *testing.T) *MultiNodeClient { + url := SetupLocalSolNode(t) + privKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + pubKey := privKey.PublicKey() + FundTestAccounts(t, []solana.PublicKey{pubKey}, url) + + requestTimeout := 5 * time.Second + lggr := logger.Test(t) + cfg := config.NewDefault() + enabled := true + cfg.MultiNode.MultiNode.Enabled = &enabled + + c, err := NewMultiNodeClient(url, cfg, requestTimeout, lggr) + require.NoError(t, err) + return c +} + +func TestMultiNodeClient_Ping(t *testing.T) { + c := initializeMultiNodeClient(t) + require.NoError(t, c.Ping(tests.Context(t))) +} + +func TestMultiNodeClient_LatestBlock(t *testing.T) { + c := initializeMultiNodeClient(t) + + t.Run("LatestBlock", func(t *testing.T) { + head, err := c.LatestBlock(tests.Context(t)) + require.NoError(t, err) + require.Equal(t, true, head.IsValid()) + require.NotEqual(t, solana.Hash{}, head.BlockHash) + }) + + t.Run("LatestFinalizedBlock", func(t *testing.T) { + finalizedHead, err := c.LatestFinalizedBlock(tests.Context(t)) + require.NoError(t, err) + require.Equal(t, true, finalizedHead.IsValid()) + require.NotEqual(t, solana.Hash{}, finalizedHead.BlockHash) + }) +} + +func TestMultiNodeClient_HeadSubscriptions(t *testing.T) { + c := initializeMultiNodeClient(t) + + t.Run("SubscribeToHeads", func(t *testing.T) { + ch, sub, err := c.SubscribeToHeads(tests.Context(t)) + require.NoError(t, err) + defer sub.Unsubscribe() + + ctx, cancel := context.WithTimeout(tests.Context(t), time.Minute) + defer cancel() + select { + case head := <-ch: + require.NotEqual(t, solana.Hash{}, head.BlockHash) + latest, _ := c.GetInterceptedChainInfo() + require.Equal(t, head.BlockNumber(), latest.BlockNumber) + case <-ctx.Done(): + t.Fatal("failed to receive head: ", ctx.Err()) + } + }) + + t.Run("SubscribeToFinalizedHeads", func(t *testing.T) { + finalizedCh, finalizedSub, err := c.SubscribeToFinalizedHeads(tests.Context(t)) + require.NoError(t, err) + defer finalizedSub.Unsubscribe() + + ctx, cancel := context.WithTimeout(tests.Context(t), time.Minute) + defer cancel() + select { + case finalizedHead := <-finalizedCh: + require.NotEqual(t, solana.Hash{}, finalizedHead.BlockHash) + latest, _ := c.GetInterceptedChainInfo() + require.Equal(t, finalizedHead.BlockNumber(), latest.FinalizedBlockNumber) + case <-ctx.Done(): + t.Fatal("failed to receive finalized head: ", ctx.Err()) + } + }) +} + +type mockSub struct { + unsubscribed bool +} + +func newMockSub() *mockSub { + return &mockSub{unsubscribed: false} +} + +func (s *mockSub) Unsubscribe() { + s.unsubscribed = true +} +func (s *mockSub) Err() <-chan error { + return nil +} + +func TestMultiNodeClient_RegisterSubs(t *testing.T) { + c := initializeMultiNodeClient(t) + + t.Run("registerSub", func(t *testing.T) { + sub := newMockSub() + err := c.registerSub(sub, make(chan struct{})) + require.NoError(t, err) + require.Len(t, c.subs, 1) + c.UnsubscribeAllExcept() + }) + + t.Run("chStopInFlight returns error and unsubscribes", func(t *testing.T) { + chStopInFlight := make(chan struct{}) + close(chStopInFlight) + sub := newMockSub() + err := c.registerSub(sub, chStopInFlight) + require.Error(t, err) + require.Equal(t, true, sub.unsubscribed) + }) + + t.Run("UnsubscribeAllExcept", func(t *testing.T) { + chStopInFlight := make(chan struct{}) + sub1 := newMockSub() + sub2 := newMockSub() + err := c.registerSub(sub1, chStopInFlight) + require.NoError(t, err) + err = c.registerSub(sub2, chStopInFlight) + require.NoError(t, err) + require.Len(t, c.subs, 2) + + c.UnsubscribeAllExcept(sub1) + require.Len(t, c.subs, 1) + require.Equal(t, true, sub2.unsubscribed) + + c.UnsubscribeAllExcept() + require.Len(t, c.subs, 0) + require.Equal(t, true, sub1.unsubscribed) + }) +}