From 8692adb8b1a5e62313c25c75eeed639f05251971 Mon Sep 17 00:00:00 2001 From: Bob Stevens <35038919+restevens402@users.noreply.github.com> Date: Sun, 6 Oct 2024 16:09:22 -0700 Subject: [PATCH] Create MockHost for testing and update method visibility Introduce MockHost and MockNetwork structs for unit tests to simulate network interactions. Adjust method visibility in WorkHandlerManager for better accessibility and fix logging in AppConfig. Add new unit tests for worker selection and handlers. --- cmd/masa-node/config.go | 2 +- pkg/config/app.go | 2 +- pkg/tests/mock_host.go | 203 ++++++++++++++++++++++++++++++++++ pkg/tests/worker_test.go | 133 ++++++++++++++++++++++ pkg/workers/worker_manager.go | 6 +- 5 files changed, 341 insertions(+), 5 deletions(-) create mode 100644 pkg/tests/mock_host.go create mode 100644 pkg/tests/worker_test.go diff --git a/cmd/masa-node/config.go b/cmd/masa-node/config.go index a149b1ba..1c394453 100644 --- a/cmd/masa-node/config.go +++ b/cmd/masa-node/config.go @@ -3,7 +3,7 @@ package main import ( "github.com/masa-finance/masa-oracle/node" "github.com/masa-finance/masa-oracle/pkg/config" - pubsub "github.com/masa-finance/masa-oracle/pkg/pubsub" + "github.com/masa-finance/masa-oracle/pkg/pubsub" "github.com/masa-finance/masa-oracle/pkg/workers" ) diff --git a/pkg/config/app.go b/pkg/config/app.go index 26d48d8a..3d7fdf87 100644 --- a/pkg/config/app.go +++ b/pkg/config/app.go @@ -177,7 +177,7 @@ func (c *AppConfig) setFileConfig(path string) { func (c *AppConfig) setEnvVariableConfig() { err := godotenv.Load() if err != nil { - logrus.Error("[-] Error loading .env file") + logrus.Errorf("[-] Error loading .env file %s", err) } viper.AutomaticEnv() } diff --git a/pkg/tests/mock_host.go b/pkg/tests/mock_host.go new file mode 100644 index 00000000..2c900651 --- /dev/null +++ b/pkg/tests/mock_host.go @@ -0,0 +1,203 @@ +package tests + +import ( + "context" + "fmt" + + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/multiformats/go-multiaddr" +) + +type MockHost struct { + id peer.ID +} + +func (m *MockHost) Peerstore() peerstore.Peerstore { + fmt.Printf("Peerstore called\n") + return nil +} + +func (m *MockHost) Addrs() []multiaddr.Multiaddr { + fmt.Printf("Addrs called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1} +} + +func (m *MockHost) Network() network.Network { + fmt.Printf("Network called\n") + return &MockNetwork{} +} + +func (m *MockHost) Mux() protocol.Switch { + fmt.Printf("Mux called\n") + return nil +} + +func (m *MockHost) Connect(ctx context.Context, pi peer.AddrInfo) error { + fmt.Printf("Connect called with peer info: %v\n", pi) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil +} + +func (m *MockHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { + fmt.Printf("SetStreamHandler called with protocol ID: %s\n", pid) + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockHost) SetStreamHandlerMatch(id protocol.ID, f func(protocol.ID) bool, handler network.StreamHandler) { + fmt.Printf("SetStreamHandlerMatch called with protocol ID: %s\n", id) + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockHost) RemoveStreamHandler(pid protocol.ID) { + fmt.Printf("RemoveStreamHandler called with protocol ID: %s\n", pid) +} + +func (m *MockHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + fmt.Printf("NewStream called with peer: %s, protocol IDs: %v\n", p, pids) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockHost) Close() error { + fmt.Printf("Close called\n") + return nil +} + +func (m *MockHost) ConnManager() connmgr.ConnManager { + fmt.Printf("ConnManager called\n") + return nil +} + +func (m *MockHost) EventBus() event.Bus { + fmt.Printf("EventBus called\n") + return nil +} + +func (m *MockHost) ID() peer.ID { + fmt.Printf("ID called\n") + return m.id +} + +type MockNetwork struct{} + +func (m *MockNetwork) Close() error { + fmt.Printf("Close called\n") + return nil +} + +func (m *MockNetwork) CanDial(p peer.ID, addr multiaddr.Multiaddr) bool { + fmt.Printf("CanDial called with peer: %s, addr: %s\n", p, addr) + return true +} + +func (m *MockNetwork) DialPeer(ctx context.Context, id peer.ID) (network.Conn, error) { + fmt.Printf("DialPeer called with peer: %s\n", id) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockNetwork) SetStreamHandler(handler network.StreamHandler) { + fmt.Printf("SetStreamHandler called\n") + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockNetwork) NewStream(ctx context.Context, id peer.ID) (network.Stream, error) { + fmt.Printf("NewStream called with peer: %s\n", id) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockNetwork) Listen(m2 ...multiaddr.Multiaddr) error { + fmt.Printf("Listen called with addresses: %v\n", m2) + return nil +} + +func (m *MockNetwork) ResourceManager() network.ResourceManager { + fmt.Printf("ResourceManager called\n") + return nil +} + +func (m *MockNetwork) Peerstore() peerstore.Peerstore { + fmt.Printf("Peerstore called\n") + return nil +} + +func (m *MockNetwork) LocalPeer() peer.ID { + fmt.Printf("LocalPeer called\n") + return "mockLocalPeerID" +} + +func (m *MockNetwork) ListenAddresses() []multiaddr.Multiaddr { + fmt.Printf("ListenAddresses called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1} +} + +func (m *MockNetwork) InterfaceListenAddresses() ([]multiaddr.Multiaddr, error) { + fmt.Printf("InterfaceListenAddresses called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1}, nil +} + +func (m *MockNetwork) Connectedness(p peer.ID) network.Connectedness { + fmt.Printf("Connectedness called with peer: %s\n", p) + return network.NotConnected +} + +func (m *MockNetwork) Peers() []peer.ID { + fmt.Printf("Peers called\n") + return []peer.ID{} +} + +func (m *MockNetwork) Conns() []network.Conn { + fmt.Printf("Conns called\n") + return []network.Conn{} +} + +func (m *MockNetwork) ConnsToPeer(p peer.ID) []network.Conn { + fmt.Printf("ConnsToPeer called with peer: %s\n", p) + return []network.Conn{} +} + +func (m *MockNetwork) Notify(notifier network.Notifiee) { + fmt.Printf("Notify called\n") + if notifier == nil { + fmt.Printf("nil notifier\n") + } +} + +func (m *MockNetwork) StopNotify(notifier network.Notifiee) { + fmt.Printf("StopNotify called\n") + if notifier == nil { + fmt.Printf("nil notifier\n") + } +} + +func (m *MockNetwork) ClosePeer(p peer.ID) error { + fmt.Printf("ClosePeer called with peer: %s\n", p) + return nil +} + +func (m *MockNetwork) RemovePeer(p peer.ID) { + fmt.Printf("RemovePeer called with peer: %s\n", p) +} diff --git a/pkg/tests/worker_test.go b/pkg/tests/worker_test.go new file mode 100644 index 00000000..08df8163 --- /dev/null +++ b/pkg/tests/worker_test.go @@ -0,0 +1,133 @@ +package tests + +import ( + "context" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/masa-oracle/node" + "github.com/masa-finance/masa-oracle/pkg/pubsub" + "github.com/masa-finance/masa-oracle/pkg/workers" + datatypes "github.com/masa-finance/masa-oracle/pkg/workers/types" +) + +func TestWorkers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Workers Suite") +} + +var _ = Describe("Worker Selection", func() { + var ( + oracleNode1 *node.OracleNode + oracleNode2 *node.OracleNode + category pubsub.WorkerCategory + ) + + BeforeEach(func() { + ctx := context.Background() + + // Start the first node with a random identity + n1, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper) + Expect(err).ToNot(HaveOccurred()) + err = n1.Start() + Expect(err).ToNot(HaveOccurred()) + + // Get the address of the first node to use as a bootstrap node + addrs, err := n1.GetP2PMultiAddrs() + Expect(err).ToNot(HaveOccurred()) + + var bootNodes []string + for _, addr := range addrs { + bootNodes = append(bootNodes, addr.String()) + } + + // Start the second node with a random identity and bootstrap to the first node + n2, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTelegramScraper, node.WithBootNodes(bootNodes...)) + Expect(err).ToNot(HaveOccurred()) + err = n2.Start() + Expect(err).ToNot(HaveOccurred()) + + n2.Host = &MockHost{id: "mockHostID1"} + oracleNode1 = n1 + oracleNode2 = n2 + category = pubsub.CategoryTwitter + }) + + AfterEach(func() { + //oracleNode1.Stop() + //oracleNode2.Stop() + }) + + Describe("GetEligibleWorkers", func() { + It("should return empty remote workers and a local worker", func() { + // Wait for the nodes to see each other + Eventually(func() bool { + datas := oracleNode1.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + Eventually(func() bool { + datas := oracleNode2.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + remoteWorkers, localWorker := workers.GetEligibleWorkers(oracleNode1, category) + + Expect(remoteWorkers).To(BeEmpty()) + Expect(localWorker).ToNot(BeNil()) + }) + }) +}) + +var _ = Describe("WorkHandlerManager", func() { + var ( + oracleNode *node.OracleNode + manager *workers.WorkHandlerManager + ) + + BeforeEach(func() { + manager = workers.NewWorkHandlerManager(workers.EnableTwitterWorker) + ctx := context.Background() + var err error + // Start the first node with a random identity + oracleNode, err = node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper) + Expect(err).ToNot(HaveOccurred()) + err = oracleNode.Start() + Expect(err).ToNot(HaveOccurred()) + }) + + Describe("Add and Get WorkHandler", func() { + It("should add and retrieve a work handler", func() { + handler, exists := manager.GetWorkHandler(datatypes.Twitter) + Expect(exists).To(BeTrue()) + Expect(handler).ToNot(BeNil()) + }) + + It("should return false for non-existent work handler", func() { + _, exists := manager.GetWorkHandler(datatypes.WorkerType("NonExistent")) + Expect(exists).To(BeFalse()) + }) + }) + + Describe("DistributeWork", func() { + It("should distribute work to eligible workers", func() { + workRequest := datatypes.WorkRequest{ + WorkType: datatypes.Twitter, + Data: []byte(`{"query": "test", "count": 10}`), + } + response := manager.DistributeWork(oracleNode, workRequest) + Expect(response.Error).To(BeEmpty()) + }) + + It("should handle errors in work distribution", func() { + workRequest := datatypes.WorkRequest{ + WorkType: datatypes.WorkerType("InvalidType"), + Data: []byte(`{"query": "test", "count": 10}`), + } + response := manager.DistributeWork(nil, workRequest) + Expect(response.Error).ToNot(BeEmpty()) + }) + }) +}) diff --git a/pkg/workers/worker_manager.go b/pkg/workers/worker_manager.go index 1665fc61..e7558b3a 100644 --- a/pkg/workers/worker_manager.go +++ b/pkg/workers/worker_manager.go @@ -84,8 +84,8 @@ func (whm *WorkHandlerManager) addWorkHandler(wType data_types.WorkerType, handl whm.handlers[wType] = &WorkHandlerInfo{Handler: handler} } -// getWorkHandler retrieves a registered work handler by name. -func (whm *WorkHandlerManager) getWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) { +// GetWorkHandler retrieves a registered work handler by name. +func (whm *WorkHandlerManager) GetWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) { whm.mu.RLock() defer whm.mu.RUnlock() info, exists := whm.handlers[wType] @@ -248,7 +248,7 @@ func (whm *WorkHandlerManager) sendWorkToWorker(node *node.OracleNode, worker da // ExecuteWork finds and executes the work handler associated with the given name. // It tracks the call count and execution duration for the handler. func (whm *WorkHandlerManager) ExecuteWork(workRequest data_types.WorkRequest) (response data_types.WorkResponse) { - handler, exists := whm.getWorkHandler(workRequest.WorkType) + handler, exists := whm.GetWorkHandler(workRequest.WorkType) if !exists { return data_types.WorkResponse{Error: ErrHandlerNotFound.Error()} }