diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 77252a5c..2590f69b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,4 +22,4 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: - args: --timeout=5m --config dev/.golangci.yaml + args: --timeout=5m --config .golangci.yaml diff --git a/dev/.golangci.yaml b/.golangci.yaml similarity index 100% rename from dev/.golangci.yaml rename to .golangci.yaml diff --git a/.mockery.yaml b/.mockery.yaml index 9b00d38a..522551d5 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -4,6 +4,9 @@ mockname: "Mock{{.InterfaceName}}" outpkg: mocks filename: "mock_{{.InterfaceName}}.go" packages: + github.com/xmtp/xmtpd/pkg/registry: + interfaces: + NodesContract: github.com/xmtp/xmtpd/pkg/indexer/blockchain: interfaces: ChainClient: diff --git a/README.md b/README.md index 565bff4f..7a53547b 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,9 @@ `xmtpd` (XMTP daemon) is an experimental version of XMTP node software. It is **not** the node software that currently forms the XMTP network. -After `xmtpd` meets specific functional requirements, the plan is for it to become the node software that powers the XMTP network. +After `xmtpd` meets specific functional requirements, the plan is for it to become the node software that powers the XMTP network. -Some of these requirements include reaching functional parity with the current node software and reliably performing data replication without data loss. +Some of these requirements include reaching functional parity with the current node software and reliably performing data replication without data loss. To keep up with and provide feedback about `xmtpd` development, see the [Issues tab](https://github.com/xmtp/xmtpd/issues) in this repo. @@ -46,7 +46,7 @@ dev/down To start the `xmtpd` node, run: ```sh -dev/start +dev/run ``` ## Test the node @@ -57,7 +57,7 @@ To run tests against the `xmtpd` node, run: dev/test ``` -These tests provide a full suite of unit and integration tests for the `xmtpd` repo to help ensure and maintain correctness of the code over time and to avoid regressions as the code evolves. You can explore the tests by taking a look at any files with the suffix `_test.go`. +These tests provide a full suite of unit and integration tests for the `xmtpd` repo to help ensure and maintain correctness of the code over time and to avoid regressions as the code evolves. You can explore the tests by taking a look at any files with the suffix `_test.go`. ## Monitor the node @@ -65,19 +65,19 @@ The `xmtpd` node build provides two options for monitoring your node. - To access your local Prometheus instance to explore node metrics, run: - ```sh - open http://localhost:9090 - ``` + ```sh + open http://localhost:9090 + ``` - To learn how to query node data in Prometheus, see [Metric Types in Prometheus and PromQL](https://promlabs.com/blog/2020/09/25/metric-types-in-prometheus-and-promql) and [The Anatomy of a PromQL Query](https://promlabs.com/blog/2020/06/18/the-anatomy-of-a-promql-query/). + To learn how to query node data in Prometheus, see [Metric Types in Prometheus and PromQL](https://promlabs.com/blog/2020/09/25/metric-types-in-prometheus-and-promql) and [The Anatomy of a PromQL Query](https://promlabs.com/blog/2020/06/18/the-anatomy-of-a-promql-query/). - To access your local Grafana instance to explore and build node dashboards, run: - ```sh - open http://localhost:3000 - ``` + ```sh + open http://localhost:3000 + ``` - To learn how to visualize node data in Grafana, see [Prometheus Histograms with Grafana Heatmaps](https://towardsdatascience.com/prometheus-histograms-with-grafana-heatmaps-d556c28612c7) and [How to visualize Prometheus histograms in Grafana](https://grafana.com/blog/2020/06/23/how-to-visualize-prometheus-histograms-in-grafana/). + To learn how to visualize node data in Grafana, see [Prometheus Histograms with Grafana Heatmaps](https://towardsdatascience.com/prometheus-histograms-with-grafana-heatmaps-d556c28612c7) and [How to visualize Prometheus histograms in Grafana](https://grafana.com/blog/2020/06/23/how-to-visualize-prometheus-histograms-in-grafana/). # Contributing @@ -87,17 +87,17 @@ Please follow the [style guide](https://google.github.io/styleguide/go/decisions Submit and land a PR to https://github.com/xmtp/proto. Then run: - ```sh - dev/generate - ``` +```sh +dev/generate +``` ## Modifying the database schema Create a new migration by running: - ```sh - dev/gen-migration - ``` +```sh +dev/gen-migration +``` Fill in the migrations in the generated files. If you are unfamiliar with migrations, you may follow [this guide](https://github.com/golang-migrate/migrate/blob/master/MIGRATIONS.md). The database is PostgreSQL and the driver is PGX. @@ -105,6 +105,6 @@ Fill in the migrations in the generated files. If you are unfamiliar with migrat We use [sqlc](https://docs.sqlc.dev/en/latest/index.html) to generate the code for our DB queries. Modify the `queries.sql` file, and then run: - ```sh - sqlc generate - ``` +```sh +sqlc generate +``` diff --git a/cmd/replication/main.go b/cmd/replication/main.go index 1fc47020..ba4d2254 100644 --- a/cmd/replication/main.go +++ b/cmd/replication/main.go @@ -9,6 +9,7 @@ import ( "syscall" "github.com/jessevdk/go-flags" + "github.com/xmtp/xmtpd/pkg/config" "github.com/xmtp/xmtpd/pkg/registry" "github.com/xmtp/xmtpd/pkg/server" "github.com/xmtp/xmtpd/pkg/tracing" @@ -18,7 +19,7 @@ import ( var Commit string -var options server.Options +var options config.ServerOptions func main() { if _, err := flags.Parse(&options); err != nil { @@ -81,7 +82,7 @@ func fatal(msg string, args ...any) { log.Fatalf(msg, args...) } -func buildLogger(options server.Options) (*zap.Logger, *zap.Config, error) { +func buildLogger(options config.ServerOptions) (*zap.Logger, *zap.Config, error) { atom := zap.NewAtomicLevel() level := zapcore.InfoLevel err := level.Set(options.LogLevel) diff --git a/dev/generate b/dev/generate index 6d05c30e..e67ccd8d 100755 --- a/dev/generate +++ b/dev/generate @@ -2,6 +2,7 @@ set -e go generate ./... +rm -f pkg/mocks/* mockery ./dev/abigen diff --git a/dev/local.env b/dev/local.env new file mode 100755 index 00000000..1b13508a --- /dev/null +++ b/dev/local.env @@ -0,0 +1,11 @@ +#!/bin/bash + +source dev/contracts/.env + +export CHAIN_RPC_URL=$DOCKER_RPC_URL # From contracts/.env +export NODE_PRIVATE_KEY=$PRIVATE_KEY # From contracts/.env +export WRITER_CONNECTION_STRING="postgres://postgres:xmtp@localhost:8765/postgres?sslmode=disable" +NODES_CONTRACT_ADDRESS="$(jq -r '.deployedTo' build/Nodes.json)" # Built by contracts/deploy-local +export NODES_CONTRACT_ADDRESS +GROUP_MESSAGES_CONTRACT_ADDRESS="$(jq -r '.deployedTo' build/GroupMessages.json)" # Built by contracts/deploy-local +export GROUP_MESSAGES_CONTRACT_ADDRESS \ No newline at end of file diff --git a/dev/run b/dev/run new file mode 100755 index 00000000..571d09d1 --- /dev/null +++ b/dev/run @@ -0,0 +1,12 @@ +#!/bin/bash + +set -eu + +. dev/local.env + +go run cmd/replication/main.go \ + --db.writer-connection-string=$WRITER_CONNECTION_STRING \ + --private-key=${NODE_PRIVATE_KEY} \ + --contracts.nodes-address=$NODES_CONTRACT_ADDRESS \ + --contracts.messages-address=$GROUP_MESSAGES_CONTRACT_ADDRESS \ + --contracts.rpc-url=$CHAIN_RPC_URL \ No newline at end of file diff --git a/pkg/server/options.go b/pkg/config/options.go similarity index 66% rename from pkg/server/options.go rename to pkg/config/options.go index 06a004ef..835b89f7 100644 --- a/pkg/server/options.go +++ b/pkg/config/options.go @@ -1,15 +1,20 @@ -package server +package config import ( "time" - - "github.com/xmtp/xmtpd/pkg/indexer" ) type ApiOptions struct { Port int `short:"p" long:"port" description:"Port to listen on" default:"5050"` } +type ContractsOptions struct { + RpcUrl string `long:"rpc-url" description:"Blockchain RPC URL"` + NodesContractAddress string `long:"nodes-address" description:"Node contract address"` + MessagesContractAddress string `long:"messages-address" description:"Message contract address"` + RefreshInterval time.Duration `long:"refresh-interval" description:"Refresh interval" default:"60s"` +} + type DbOptions struct { ReaderConnectionString string `long:"reader-connection-string" description:"Reader connection string"` WriterConnectionString string `long:"writer-connection-string" description:"Writer connection string" required:"true"` @@ -19,14 +24,14 @@ type DbOptions struct { WaitForDB time.Duration `long:"wait-for" description:"wait for DB on start, up to specified duration"` } -type Options struct { +type ServerOptions struct { LogLevel string `short:"l" long:"log-level" description:"Define the logging level, supported strings are: DEBUG, INFO, WARN, ERROR, DPANIC, PANIC, FATAL, and their lower-case forms." default:"INFO"` //nolint:staticcheck LogEncoding string `long:"log-encoding" description:"Log encoding format. Either console or json" choice:"console" choice:"json" default:"console"` PrivateKeyString string `long:"private-key" description:"Private key to use for the node"` - API ApiOptions `group:"API Options" namespace:"api"` - DB DbOptions `group:"Database Options" namespace:"db"` - Contracts indexer.ContractsOptions `group:"Contracts Options" namespace:"contracts"` + API ApiOptions `group:"API Options" namespace:"api"` + DB DbOptions `group:"Database Options" namespace:"db"` + Contracts ContractsOptions `group:"Contracts Options" namespace:"contracts"` } diff --git a/pkg/indexer/indexer.go b/pkg/indexer/indexer.go index a64f4eab..3516453a 100644 --- a/pkg/indexer/indexer.go +++ b/pkg/indexer/indexer.go @@ -7,6 +7,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/xmtp/xmtpd/pkg/abis" + "github.com/xmtp/xmtpd/pkg/config" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/indexer/blockchain" "github.com/xmtp/xmtpd/pkg/indexer/storer" @@ -15,16 +16,31 @@ import ( ) // Start the indexer and run until the context is canceled -func StartIndexer(ctx context.Context, logger *zap.Logger, queries *queries.Queries, options ContractsOptions) error { - builder := blockchain.NewRpcLogStreamBuilder(options.RpcUrl, logger) +func StartIndexer( + ctx context.Context, + logger *zap.Logger, + queries *queries.Queries, + cfg config.ContractsOptions, +) error { + builder := blockchain.NewRpcLogStreamBuilder(cfg.RpcUrl, logger) messagesTopic, err := buildMessagesTopic() if err != nil { return err } - messagesChannel := builder.ListenForContractEvent(0, common.HexToAddress(options.MessagesContractAddress), []common.Hash{messagesTopic}) - indexLogs(ctx, messagesChannel, logger.Named("indexLogs").With(zap.String("contractAddress", options.MessagesContractAddress)), storer.NewGroupMessageStorer(queries, logger)) + messagesChannel := builder.ListenForContractEvent( + 0, + common.HexToAddress(cfg.MessagesContractAddress), + []common.Hash{messagesTopic}, + ) + + indexLogs( + ctx, + messagesChannel, + logger.Named("indexLogs").With(zap.String("contractAddress", cfg.MessagesContractAddress)), + storer.NewGroupMessageStorer(queries, logger), + ) streamer, err := builder.Build() if err != nil { @@ -41,7 +57,12 @@ If an event fails to be stored, and the error is retryable, it will sleep for 10 The only non-retriable errors should be things like malformed events or failed validations. */ -func indexLogs(ctx context.Context, eventChannel <-chan types.Log, logger *zap.Logger, logStorer storer.LogStorer) { +func indexLogs( + ctx context.Context, + eventChannel <-chan types.Log, + logger *zap.Logger, + logStorer storer.LogStorer, +) { var err storer.LogStorageError // We don't need to listen for the ctx.Done() here, since the eventChannel will be closed when the parent context is canceled for event := range eventChannel { diff --git a/pkg/indexer/options.go b/pkg/indexer/options.go deleted file mode 100644 index be240f45..00000000 --- a/pkg/indexer/options.go +++ /dev/null @@ -1,7 +0,0 @@ -package indexer - -type ContractsOptions struct { - RpcUrl string `log:"rpc-url" description:"Blockchain RPC URL"` - NodesContractAddress string `long:"nodes-address" description:"Node contract address"` - MessagesContractAddress string `long:"messages-address" description:"Message contract address"` -} diff --git a/pkg/mocks/mock_NodesContract.go b/pkg/mocks/mock_NodesContract.go new file mode 100644 index 00000000..8a1532b4 --- /dev/null +++ b/pkg/mocks/mock_NodesContract.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + bind "github.com/ethereum/go-ethereum/accounts/abi/bind" + mock "github.com/stretchr/testify/mock" + abis "github.com/xmtp/xmtpd/pkg/abis" +) + +// MockNodesContract is an autogenerated mock type for the NodesContract type +type MockNodesContract struct { + mock.Mock +} + +type MockNodesContract_Expecter struct { + mock *mock.Mock +} + +func (_m *MockNodesContract) EXPECT() *MockNodesContract_Expecter { + return &MockNodesContract_Expecter{mock: &_m.Mock} +} + +// AllNodes provides a mock function with given fields: opts +func (_m *MockNodesContract) AllNodes(opts *bind.CallOpts) ([]abis.NodesNodeWithId, error) { + ret := _m.Called(opts) + + if len(ret) == 0 { + panic("no return value specified for AllNodes") + } + + var r0 []abis.NodesNodeWithId + var r1 error + if rf, ok := ret.Get(0).(func(*bind.CallOpts) ([]abis.NodesNodeWithId, error)); ok { + return rf(opts) + } + if rf, ok := ret.Get(0).(func(*bind.CallOpts) []abis.NodesNodeWithId); ok { + r0 = rf(opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]abis.NodesNodeWithId) + } + } + + if rf, ok := ret.Get(1).(func(*bind.CallOpts) error); ok { + r1 = rf(opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockNodesContract_AllNodes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllNodes' +type MockNodesContract_AllNodes_Call struct { + *mock.Call +} + +// AllNodes is a helper method to define mock.On call +// - opts *bind.CallOpts +func (_e *MockNodesContract_Expecter) AllNodes(opts interface{}) *MockNodesContract_AllNodes_Call { + return &MockNodesContract_AllNodes_Call{Call: _e.mock.On("AllNodes", opts)} +} + +func (_c *MockNodesContract_AllNodes_Call) Run(run func(opts *bind.CallOpts)) *MockNodesContract_AllNodes_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*bind.CallOpts)) + }) + return _c +} + +func (_c *MockNodesContract_AllNodes_Call) Return(_a0 []abis.NodesNodeWithId, _a1 error) *MockNodesContract_AllNodes_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockNodesContract_AllNodes_Call) RunAndReturn(run func(*bind.CallOpts) ([]abis.NodesNodeWithId, error)) *MockNodesContract_AllNodes_Call { + _c.Call.Return(run) + return _c +} + +// NewMockNodesContract creates a new instance of MockNodesContract. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockNodesContract(t interface { + mock.TestingT + Cleanup(func()) +}) *MockNodesContract { + mock := &MockNodesContract{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/registry/contractRegistry.go b/pkg/registry/contractRegistry.go new file mode 100644 index 00000000..22531eed --- /dev/null +++ b/pkg/registry/contractRegistry.go @@ -0,0 +1,228 @@ +package registry + +import ( + "bytes" + "context" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/xmtp/xmtpd/pkg/abis" + "github.com/xmtp/xmtpd/pkg/config" + "go.uber.org/zap" +) + +const ( + CONTRACT_CALL_TIMEOUT = 10 * time.Second +) + +/* +* +The SmartContractRegistry notifies listeners of changes to the nodes by polling the contract +and diffing the returned node list with what is currently in memory. + +This allows it to operate statelessly and not require a database, with a trade-off for latency. + +Given how infrequently this list changes, that trade-off seems acceptable. +*/ +type SmartContractRegistry struct { + ctx context.Context + contract NodesContract + logger *zap.Logger + refreshInterval time.Duration + // Hold the raw nodes for easier comparison with results from the contract + rawNodes map[uint16]abis.NodesNode + rawNodesMutex sync.Mutex + // Hold the transformed nodes for easier processing + nodes map[uint16]Node + nodesMutex sync.RWMutex + newNodesNotifier *notifier[[]Node] + changedNodeNotifiers map[uint16]*notifier[Node] + changedNodeNotifiersMutex sync.RWMutex +} + +func NewSmartContractRegistry( + ethclient bind.ContractCaller, + logger *zap.Logger, + options config.ContractsOptions, +) (*SmartContractRegistry, error) { + contract, err := abis.NewNodesCaller( + common.HexToAddress(options.NodesContractAddress), + ethclient, + ) + + if err != nil { + return nil, err + } + + return &SmartContractRegistry{ + contract: contract, + refreshInterval: options.RefreshInterval, + logger: logger.Named("smartContractRegistry"), + newNodesNotifier: newNotifier[[]Node](), + rawNodes: make(map[uint16]abis.NodesNode), + nodes: make(map[uint16]Node), + changedNodeNotifiers: make(map[uint16]*notifier[Node]), + }, nil +} + +/* +* +Loads the initial state from the contract and starts a background refresh loop. + +To stop refreshing callers should cancel the context +* +*/ +func (s *SmartContractRegistry) Start(ctx context.Context) error { + s.ctx = ctx + // If we can't load the data at least once, fail to start the service + if err := s.refreshData(); err != nil { + return err + } + + go s.refreshLoop() + + return nil +} + +func (s *SmartContractRegistry) OnNewNodes() (<-chan []Node, CancelSubscription) { + return s.newNodesNotifier.register() +} + +func (s *SmartContractRegistry) OnChangedNode( + nodeId uint16, +) (<-chan Node, CancelSubscription) { + s.changedNodeNotifiersMutex.Lock() + defer s.changedNodeNotifiersMutex.Unlock() + + notifier, ok := s.changedNodeNotifiers[nodeId] + if !ok { + notifier = newNotifier[Node]() + s.changedNodeNotifiers[nodeId] = notifier + } + return notifier.register() +} + +func (s *SmartContractRegistry) GetNodes() ([]Node, error) { + s.nodesMutex.RLock() + defer s.nodesMutex.RUnlock() + + nodes := make([]Node, len(s.nodes)) + for idx, node := range s.nodes { + nodes[idx] = node + } + return nodes, nil +} + +func (s *SmartContractRegistry) refreshLoop() { + ticker := time.NewTicker(s.refreshInterval) + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + if err := s.refreshData(); err != nil { + s.logger.Error("Failed to refresh data", zap.Error(err)) + } + } + } +} + +func (s *SmartContractRegistry) refreshData() error { + rawNodes, err := s.loadFromContract() + if err != nil { + return err + } + + // Lock the mutex to protect against concurrent writes + s.rawNodesMutex.Lock() + defer s.rawNodesMutex.Unlock() + + newNodes := []Node{} + for _, rawNodeWithId := range rawNodes { + existingValue, ok := s.rawNodes[rawNodeWithId.NodeId] + if !ok { + // New node found + newNodes = append(newNodes, convertNode(rawNodeWithId)) + } else if !equalRawNodes(existingValue, rawNodeWithId.Node) { + s.processChangedNode(rawNodeWithId) + } else { + // No change, skip + continue + } + s.rawNodes[rawNodeWithId.NodeId] = rawNodeWithId.Node + } + + if len(newNodes) > 0 { + s.processNewNodes(newNodes) + } + + return nil +} + +func (s *SmartContractRegistry) processNewNodes(nodes []Node) { + s.logger.Info("processing new nodes", zap.Int("count", len(nodes)), zap.Any("nodes", nodes)) + s.newNodesNotifier.trigger(nodes) + + s.nodesMutex.Lock() + defer s.nodesMutex.Unlock() + for _, node := range nodes { + s.nodes[node.NodeId] = node + } +} + +func (s *SmartContractRegistry) processChangedNode(rawNode abis.NodesNodeWithId) { + s.nodesMutex.Lock() + defer s.nodesMutex.Unlock() + s.changedNodeNotifiersMutex.RLock() + defer s.changedNodeNotifiersMutex.RUnlock() + + node := convertNode(rawNode) + s.nodes[node.NodeId] = node + s.logger.Info("processing changed node", zap.Any("node", node)) + if registry, ok := s.changedNodeNotifiers[node.NodeId]; ok { + registry.trigger(node) + } +} + +func (s *SmartContractRegistry) loadFromContract() ([]abis.NodesNodeWithId, error) { + ctx, cancel := context.WithTimeout(s.ctx, CONTRACT_CALL_TIMEOUT) + defer cancel() + nodes, err := s.contract.AllNodes(&bind.CallOpts{Context: ctx}) + if err != nil { + return nil, err + } + + return nodes, nil +} + +func convertNode(rawNode abis.NodesNodeWithId) Node { + // Unmarshal the signing key. + // If invalid, mark the config as being invalid as well. Clients should treat the + // node as unhealthy in this case + signingKey, err := crypto.UnmarshalPubkey(rawNode.Node.SigningKeyPub) + isValidConfig := err == nil + + httpAddress := rawNode.Node.HttpAddress + + // Ensure the httpAddress is well formed + if !strings.HasPrefix(httpAddress, "https://") && !strings.HasPrefix(httpAddress, "http://") { + isValidConfig = false + } + + return Node{ + NodeId: rawNode.NodeId, + SigningKey: signingKey, + HttpAddress: httpAddress, + IsHealthy: rawNode.Node.IsHealthy, + IsValidConfig: isValidConfig, + } +} + +func equalRawNodes(a abis.NodesNode, b abis.NodesNode) bool { + return bytes.Equal(a.SigningKeyPub, b.SigningKeyPub) && a.HttpAddress == b.HttpAddress && + a.IsHealthy == b.IsHealthy +} diff --git a/pkg/registry/contractRegistry_test.go b/pkg/registry/contractRegistry_test.go new file mode 100644 index 00000000..0f93a29a --- /dev/null +++ b/pkg/registry/contractRegistry_test.go @@ -0,0 +1,136 @@ +package registry + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/abis" + "github.com/xmtp/xmtpd/pkg/config" + "github.com/xmtp/xmtpd/pkg/mocks" + testUtils "github.com/xmtp/xmtpd/pkg/testing" +) + +func TestContractRegistryNewNodes(t *testing.T) { + registry, err := NewSmartContractRegistry( + nil, + testUtils.NewLog(t), + config.ContractsOptions{RefreshInterval: 100 * time.Millisecond}, + ) + require.NoError(t, err) + + mockContract := mocks.NewMockNodesContract(t) + mockContract.EXPECT(). + AllNodes(mock.Anything). + Return([]abis.NodesNodeWithId{ + {NodeId: 1, Node: abis.NodesNode{HttpAddress: "http://foo.com"}}, + {NodeId: 2, Node: abis.NodesNode{HttpAddress: "https://bar.com"}}, + }, nil) + + registry.contract = mockContract + + sub, cancelSub := registry.OnNewNodes() + defer cancelSub() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + require.NoError(t, registry.Start(ctx)) + newNodes := <-sub + require.Equal( + t, + []Node{ + {NodeId: 1, HttpAddress: "http://foo.com"}, + {NodeId: 2, HttpAddress: "https://bar.com"}, + }, + newNodes, + ) +} + +func TestContractRegistryChangedNodes(t *testing.T) { + registry, err := NewSmartContractRegistry( + nil, + testUtils.NewLog(t), + config.ContractsOptions{RefreshInterval: 10 * time.Millisecond}, + ) + require.NoError(t, err) + + mockContract := mocks.NewMockNodesContract(t) + + hasSentInitialValues := false + // The first call, we'll set the address to foo.com. + // Subsequent calls will set the address to bar.com + mockContract.EXPECT(). + AllNodes(mock.Anything).RunAndReturn(func(*bind.CallOpts) ([]abis.NodesNodeWithId, error) { + httpAddress := "http://foo.com" + if !hasSentInitialValues { + hasSentInitialValues = true + } else { + httpAddress = "http://bar.com" + } + return []abis.NodesNodeWithId{ + {NodeId: 1, Node: abis.NodesNode{HttpAddress: httpAddress}}, + }, nil + }) + + // Override the contract in the registry with a mock before calling Start + registry.contract = mockContract + + sub, cancelSub := registry.OnChangedNode(1) + defer cancelSub() + counterSub, cancelCounter := registry.OnChangedNode(1) + getCurrentCount := countChannel(counterSub) + defer cancelCounter() + go func() { + for node := range sub { + require.Equal(t, node.HttpAddress, "http://bar.com") + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + require.NoError(t, registry.Start(ctx)) + time.Sleep(100 * time.Millisecond) + require.Equal(t, getCurrentCount(), 1) +} + +func TestStopOnContextCancel(t *testing.T) { + registry, err := NewSmartContractRegistry( + nil, + testUtils.NewLog(t), + config.ContractsOptions{RefreshInterval: 10 * time.Millisecond}, + ) + require.NoError(t, err) + + mockContract := mocks.NewMockNodesContract(t) + mockContract.EXPECT(). + AllNodes(mock.Anything). + RunAndReturn(func(*bind.CallOpts) ([]abis.NodesNodeWithId, error) { + return []abis.NodesNodeWithId{ + { + NodeId: uint16(rand.Intn(1000)), + Node: abis.NodesNode{HttpAddress: "http://foo.com"}, + }, + }, nil + }) + + registry.contract = mockContract + + sub, cancelSub := registry.OnNewNodes() + defer cancelSub() + getCurrentCount := countChannel(sub) + + ctx, cancel := context.WithCancel(context.Background()) + require.NoError(t, registry.Start(ctx)) + time.Sleep(100 * time.Millisecond) + require.Greater(t, getCurrentCount(), 0) + // Cancel the context + cancel() + // Wait for a little bit to give the cancellation time to take effect + time.Sleep(10 * time.Millisecond) + currentNodeCount := getCurrentCount() + time.Sleep(100 * time.Millisecond) + require.Equal(t, currentNodeCount, getCurrentCount()) +} diff --git a/pkg/registry/fixedRegistry.go b/pkg/registry/fixedRegistry.go new file mode 100644 index 00000000..38a8e064 --- /dev/null +++ b/pkg/registry/fixedRegistry.go @@ -0,0 +1,41 @@ +package registry + +import "sync" + +// TODO: Delete this or move to a test file +type FixedNodeRegistry struct { + nodes []Node + newNodeNotifier *notifier[[]Node] + changedNodeNotifiers map[uint16]*notifier[Node] + changedNodeNotifiersMutex sync.Mutex +} + +func NewFixedNodeRegistry(nodes []Node) *FixedNodeRegistry { + return &FixedNodeRegistry{nodes: nodes} +} + +func (r *FixedNodeRegistry) GetNodes() ([]Node, error) { + return r.nodes, nil +} + +func (f *FixedNodeRegistry) AddNode(node Node) { + f.nodes = append(f.nodes, node) +} + +func (f *FixedNodeRegistry) OnNewNodes() (<-chan []Node, CancelSubscription) { + return f.newNodeNotifier.register() +} + +func (f *FixedNodeRegistry) OnChangedNode( + nodeId uint16, +) (<-chan Node, CancelSubscription) { + f.changedNodeNotifiersMutex.Lock() + defer f.changedNodeNotifiersMutex.Unlock() + + registry, ok := f.changedNodeNotifiers[nodeId] + if !ok { + registry = newNotifier[Node]() + f.changedNodeNotifiers[nodeId] = registry + } + return registry.register() +} diff --git a/pkg/registry/interface.go b/pkg/registry/interface.go new file mode 100644 index 00000000..79de34a9 --- /dev/null +++ b/pkg/registry/interface.go @@ -0,0 +1,38 @@ +package registry + +import ( + "crypto/ecdsa" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/xmtp/xmtpd/pkg/abis" +) + +type Node struct { + NodeId uint16 + SigningKey *ecdsa.PublicKey + HttpAddress string + IsHealthy bool + IsValidConfig bool +} + +/* +* +A dumbed down interface of abis.NodesCaller for generating mocks +*/ +type NodesContract interface { + AllNodes(opts *bind.CallOpts) ([]abis.NodesNodeWithId, error) +} + +// Unregister the callback +type CancelSubscription func() + +/* +* +The NodeRegistry is responsible for fetching the list of nodes from the registry contract +and notifying listeners when the list of nodes changes. +*/ +type NodeRegistry interface { + GetNodes() ([]Node, error) + OnNewNodes() (<-chan []Node, CancelSubscription) + OnChangedNode(uint16) (<-chan Node, CancelSubscription) +} diff --git a/pkg/registry/notifier.go b/pkg/registry/notifier.go new file mode 100644 index 00000000..69da4e8d --- /dev/null +++ b/pkg/registry/notifier.go @@ -0,0 +1,42 @@ +package registry + +import ( + "sync" +) + +type notifier[ValueType any] struct { + channels map[chan<- ValueType]bool + mutex sync.RWMutex +} + +func newNotifier[ValueType any]() *notifier[ValueType] { + return ¬ifier[ValueType]{ + channels: make(map[chan<- ValueType]bool), + } +} + +func (c *notifier[Node]) register() (<-chan Node, CancelSubscription) { + c.mutex.Lock() + defer c.mutex.Unlock() + newChannel := make(chan Node) + c.channels[newChannel] = true + + return newChannel, func() { + c.mutex.Lock() + defer c.mutex.Unlock() + close(newChannel) + delete(c.channels, newChannel) + } +} + +func (c *notifier[any]) trigger(value any) { + c.mutex.RLock() + defer c.mutex.RUnlock() + for channel := range c.channels { + + // Write to the channel in a goroutine to avoid blocking the caller + go func(channel chan<- any) { + channel <- value + }(channel) + } +} diff --git a/pkg/registry/notifier_test.go b/pkg/registry/notifier_test.go new file mode 100644 index 00000000..3cf7faf3 --- /dev/null +++ b/pkg/registry/notifier_test.go @@ -0,0 +1,85 @@ +package registry + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNotifier(t *testing.T) { + registry := newNotifier[int]() + channel, cancel := registry.register() + getCurrentCount := countChannel(channel) + + // Make sure the value is getting writter to the cannel + registry.trigger(1) + // Sleep for 10ms since we read from the channel in a goroutinee + time.Sleep(10 * time.Millisecond) + require.Equal(t, 1, getCurrentCount()) + + // Trigger again and make sure it still works + registry.trigger(1) + time.Sleep(10 * time.Millisecond) + require.Equal(t, 2, getCurrentCount()) + + // Unregister the subscription + cancel() + registry.trigger(1) + time.Sleep(10 * time.Millisecond) + // Make sure the count hasn't changed + require.Equal(t, 2, getCurrentCount()) +} + +func TestNotifierMultiple(t *testing.T) { + registry := newNotifier[int]() + + channel1, cancel1 := registry.register() + getCurrentCount1 := countChannel(channel1) + channel2, cancel2 := registry.register() + getCurrentCount2 := countChannel(channel2) + + registry.trigger(1) + time.Sleep(10 * time.Millisecond) + require.Equal(t, 1, getCurrentCount1()) + require.Equal(t, 1, getCurrentCount2()) + + cancel1() + registry.trigger(1) + time.Sleep(10 * time.Millisecond) + require.Equal(t, 1, getCurrentCount1()) + require.Equal(t, 2, getCurrentCount2()) + cancel2() +} + +func TestNotifierConcurrent(t *testing.T) { + registry := newNotifier[int]() + channel, cancel := registry.register() + getCurrentCount := countChannel(channel) + defer cancel() + + for i := 0; i < 100; i++ { + go registry.trigger(1) + } + time.Sleep(30 * time.Millisecond) + require.Equal(t, 100, getCurrentCount()) +} + +func countChannel[Kind any](ch <-chan Kind) func() int { + var count int + var mutex sync.RWMutex + go func() { + for range ch { + mutex.Lock() + count++ + mutex.Unlock() + } + }() + + return func() int { + mutex.RLock() + defer mutex.RUnlock() + return count + } +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go deleted file mode 100644 index ddd83221..00000000 --- a/pkg/registry/registry.go +++ /dev/null @@ -1,31 +0,0 @@ -package registry - -type Node struct { - NodeId int - SigningKey []byte - HttpAddress string - MtlsCert []byte -} - -type NodeRegistry interface { - GetNodes() ([]Node, error) - // OnChange() -} - -// TODO: Delete this or move to a test file - -type FixedNodeRegistry struct { - nodes []Node -} - -func NewFixedNodeRegistry(nodes []Node) *FixedNodeRegistry { - return &FixedNodeRegistry{nodes: nodes} -} - -func (r *FixedNodeRegistry) GetNodes() ([]Node, error) { - return r.nodes, nil -} - -func (f *FixedNodeRegistry) AddNode(node Node) { - f.nodes = append(f.nodes, node) -} diff --git a/pkg/server/server.go b/pkg/server/server.go index 2c169066..023d6899 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,17 +7,19 @@ import ( "net" "os" "os/signal" + "strings" "syscall" "github.com/ethereum/go-ethereum/crypto" "github.com/xmtp/xmtpd/pkg/api" + "github.com/xmtp/xmtpd/pkg/config" "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/registry" "go.uber.org/zap" ) type ReplicationServer struct { - options Options + options config.ServerOptions log *zap.Logger ctx context.Context cancel context.CancelFunc @@ -28,7 +30,7 @@ type ReplicationServer struct { // Can add reader DB later if needed } -func NewReplicationServer(ctx context.Context, log *zap.Logger, options Options, nodeRegistry registry.NodeRegistry) (*ReplicationServer, error) { +func NewReplicationServer(ctx context.Context, log *zap.Logger, options config.ServerOptions, nodeRegistry registry.NodeRegistry) (*ReplicationServer, error) { var err error s := &ReplicationServer{ options: options, @@ -72,5 +74,5 @@ func (s *ReplicationServer) Shutdown() { } func parsePrivateKey(privateKeyString string) (*ecdsa.PrivateKey, error) { - return crypto.HexToECDSA(privateKeyString) + return crypto.HexToECDSA(strings.TrimPrefix(privateKeyString, "0x")) } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 9bf33b04..6e0d2674 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -8,6 +8,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/config" "github.com/xmtp/xmtpd/pkg/registry" test "github.com/xmtp/xmtpd/pkg/testing" ) @@ -19,12 +20,12 @@ func NewTestServer(t *testing.T, registry registry.NodeRegistry) *ReplicationSer privateKey, err := crypto.GenerateKey() require.NoError(t, err) - server, err := NewReplicationServer(context.Background(), log, Options{ + server, err := NewReplicationServer(context.Background(), log, config.ServerOptions{ PrivateKeyString: hex.EncodeToString(crypto.FromECDSA(privateKey)), - API: ApiOptions{ + API: config.ApiOptions{ Port: 0, }, - DB: DbOptions{ + DB: config.DbOptions{ WriterConnectionString: WRITER_DB_CONNECTION_STRING, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10,