diff --git a/cmd/masa-node/config.go b/cmd/masa-node/config.go index a149b1ba..1c394453 100644 --- a/cmd/masa-node/config.go +++ b/cmd/masa-node/config.go @@ -3,7 +3,7 @@ package main import ( "github.com/masa-finance/masa-oracle/node" "github.com/masa-finance/masa-oracle/pkg/config" - pubsub "github.com/masa-finance/masa-oracle/pkg/pubsub" + "github.com/masa-finance/masa-oracle/pkg/pubsub" "github.com/masa-finance/masa-oracle/pkg/workers" ) diff --git a/pkg/config/app.go b/pkg/config/app.go index 26d48d8a..3d7fdf87 100644 --- a/pkg/config/app.go +++ b/pkg/config/app.go @@ -177,7 +177,7 @@ func (c *AppConfig) setFileConfig(path string) { func (c *AppConfig) setEnvVariableConfig() { err := godotenv.Load() if err != nil { - logrus.Error("[-] Error loading .env file") + logrus.Errorf("[-] Error loading .env file %s", err) } viper.AutomaticEnv() } diff --git a/pkg/tests/mock_host.go b/pkg/tests/mock_host.go new file mode 100644 index 00000000..2c900651 --- /dev/null +++ b/pkg/tests/mock_host.go @@ -0,0 +1,203 @@ +package tests + +import ( + "context" + "fmt" + + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/multiformats/go-multiaddr" +) + +type MockHost struct { + id peer.ID +} + +func (m *MockHost) Peerstore() peerstore.Peerstore { + fmt.Printf("Peerstore called\n") + return nil +} + +func (m *MockHost) Addrs() []multiaddr.Multiaddr { + fmt.Printf("Addrs called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1} +} + +func (m *MockHost) Network() network.Network { + fmt.Printf("Network called\n") + return &MockNetwork{} +} + +func (m *MockHost) Mux() protocol.Switch { + fmt.Printf("Mux called\n") + return nil +} + +func (m *MockHost) Connect(ctx context.Context, pi peer.AddrInfo) error { + fmt.Printf("Connect called with peer info: %v\n", pi) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil +} + +func (m *MockHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { + fmt.Printf("SetStreamHandler called with protocol ID: %s\n", pid) + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockHost) SetStreamHandlerMatch(id protocol.ID, f func(protocol.ID) bool, handler network.StreamHandler) { + fmt.Printf("SetStreamHandlerMatch called with protocol ID: %s\n", id) + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockHost) RemoveStreamHandler(pid protocol.ID) { + fmt.Printf("RemoveStreamHandler called with protocol ID: %s\n", pid) +} + +func (m *MockHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + fmt.Printf("NewStream called with peer: %s, protocol IDs: %v\n", p, pids) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockHost) Close() error { + fmt.Printf("Close called\n") + return nil +} + +func (m *MockHost) ConnManager() connmgr.ConnManager { + fmt.Printf("ConnManager called\n") + return nil +} + +func (m *MockHost) EventBus() event.Bus { + fmt.Printf("EventBus called\n") + return nil +} + +func (m *MockHost) ID() peer.ID { + fmt.Printf("ID called\n") + return m.id +} + +type MockNetwork struct{} + +func (m *MockNetwork) Close() error { + fmt.Printf("Close called\n") + return nil +} + +func (m *MockNetwork) CanDial(p peer.ID, addr multiaddr.Multiaddr) bool { + fmt.Printf("CanDial called with peer: %s, addr: %s\n", p, addr) + return true +} + +func (m *MockNetwork) DialPeer(ctx context.Context, id peer.ID) (network.Conn, error) { + fmt.Printf("DialPeer called with peer: %s\n", id) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockNetwork) SetStreamHandler(handler network.StreamHandler) { + fmt.Printf("SetStreamHandler called\n") + if handler == nil { + fmt.Printf("nil handler\n") + } +} + +func (m *MockNetwork) NewStream(ctx context.Context, id peer.ID) (network.Stream, error) { + fmt.Printf("NewStream called with peer: %s\n", id) + if ctx == nil { + fmt.Printf("nil context\n") + } + return nil, nil +} + +func (m *MockNetwork) Listen(m2 ...multiaddr.Multiaddr) error { + fmt.Printf("Listen called with addresses: %v\n", m2) + return nil +} + +func (m *MockNetwork) ResourceManager() network.ResourceManager { + fmt.Printf("ResourceManager called\n") + return nil +} + +func (m *MockNetwork) Peerstore() peerstore.Peerstore { + fmt.Printf("Peerstore called\n") + return nil +} + +func (m *MockNetwork) LocalPeer() peer.ID { + fmt.Printf("LocalPeer called\n") + return "mockLocalPeerID" +} + +func (m *MockNetwork) ListenAddresses() []multiaddr.Multiaddr { + fmt.Printf("ListenAddresses called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1} +} + +func (m *MockNetwork) InterfaceListenAddresses() ([]multiaddr.Multiaddr, error) { + fmt.Printf("InterfaceListenAddresses called\n") + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + return []multiaddr.Multiaddr{addr1}, nil +} + +func (m *MockNetwork) Connectedness(p peer.ID) network.Connectedness { + fmt.Printf("Connectedness called with peer: %s\n", p) + return network.NotConnected +} + +func (m *MockNetwork) Peers() []peer.ID { + fmt.Printf("Peers called\n") + return []peer.ID{} +} + +func (m *MockNetwork) Conns() []network.Conn { + fmt.Printf("Conns called\n") + return []network.Conn{} +} + +func (m *MockNetwork) ConnsToPeer(p peer.ID) []network.Conn { + fmt.Printf("ConnsToPeer called with peer: %s\n", p) + return []network.Conn{} +} + +func (m *MockNetwork) Notify(notifier network.Notifiee) { + fmt.Printf("Notify called\n") + if notifier == nil { + fmt.Printf("nil notifier\n") + } +} + +func (m *MockNetwork) StopNotify(notifier network.Notifiee) { + fmt.Printf("StopNotify called\n") + if notifier == nil { + fmt.Printf("nil notifier\n") + } +} + +func (m *MockNetwork) ClosePeer(p peer.ID) error { + fmt.Printf("ClosePeer called with peer: %s\n", p) + return nil +} + +func (m *MockNetwork) RemovePeer(p peer.ID) { + fmt.Printf("RemovePeer called with peer: %s\n", p) +} diff --git a/pkg/tests/worker_test.go b/pkg/tests/worker_test.go new file mode 100644 index 00000000..08df8163 --- /dev/null +++ b/pkg/tests/worker_test.go @@ -0,0 +1,133 @@ +package tests + +import ( + "context" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/masa-finance/masa-oracle/node" + "github.com/masa-finance/masa-oracle/pkg/pubsub" + "github.com/masa-finance/masa-oracle/pkg/workers" + datatypes "github.com/masa-finance/masa-oracle/pkg/workers/types" +) + +func TestWorkers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Workers Suite") +} + +var _ = Describe("Worker Selection", func() { + var ( + oracleNode1 *node.OracleNode + oracleNode2 *node.OracleNode + category pubsub.WorkerCategory + ) + + BeforeEach(func() { + ctx := context.Background() + + // Start the first node with a random identity + n1, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper) + Expect(err).ToNot(HaveOccurred()) + err = n1.Start() + Expect(err).ToNot(HaveOccurred()) + + // Get the address of the first node to use as a bootstrap node + addrs, err := n1.GetP2PMultiAddrs() + Expect(err).ToNot(HaveOccurred()) + + var bootNodes []string + for _, addr := range addrs { + bootNodes = append(bootNodes, addr.String()) + } + + // Start the second node with a random identity and bootstrap to the first node + n2, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTelegramScraper, node.WithBootNodes(bootNodes...)) + Expect(err).ToNot(HaveOccurred()) + err = n2.Start() + Expect(err).ToNot(HaveOccurred()) + + n2.Host = &MockHost{id: "mockHostID1"} + oracleNode1 = n1 + oracleNode2 = n2 + category = pubsub.CategoryTwitter + }) + + AfterEach(func() { + //oracleNode1.Stop() + //oracleNode2.Stop() + }) + + Describe("GetEligibleWorkers", func() { + It("should return empty remote workers and a local worker", func() { + // Wait for the nodes to see each other + Eventually(func() bool { + datas := oracleNode1.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + Eventually(func() bool { + datas := oracleNode2.NodeTracker.GetAllNodeData() + return len(datas) == 2 + }, "30s").Should(BeTrue()) + + remoteWorkers, localWorker := workers.GetEligibleWorkers(oracleNode1, category) + + Expect(remoteWorkers).To(BeEmpty()) + Expect(localWorker).ToNot(BeNil()) + }) + }) +}) + +var _ = Describe("WorkHandlerManager", func() { + var ( + oracleNode *node.OracleNode + manager *workers.WorkHandlerManager + ) + + BeforeEach(func() { + manager = workers.NewWorkHandlerManager(workers.EnableTwitterWorker) + ctx := context.Background() + var err error + // Start the first node with a random identity + oracleNode, err = node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper) + Expect(err).ToNot(HaveOccurred()) + err = oracleNode.Start() + Expect(err).ToNot(HaveOccurred()) + }) + + Describe("Add and Get WorkHandler", func() { + It("should add and retrieve a work handler", func() { + handler, exists := manager.GetWorkHandler(datatypes.Twitter) + Expect(exists).To(BeTrue()) + Expect(handler).ToNot(BeNil()) + }) + + It("should return false for non-existent work handler", func() { + _, exists := manager.GetWorkHandler(datatypes.WorkerType("NonExistent")) + Expect(exists).To(BeFalse()) + }) + }) + + Describe("DistributeWork", func() { + It("should distribute work to eligible workers", func() { + workRequest := datatypes.WorkRequest{ + WorkType: datatypes.Twitter, + Data: []byte(`{"query": "test", "count": 10}`), + } + response := manager.DistributeWork(oracleNode, workRequest) + Expect(response.Error).To(BeEmpty()) + }) + + It("should handle errors in work distribution", func() { + workRequest := datatypes.WorkRequest{ + WorkType: datatypes.WorkerType("InvalidType"), + Data: []byte(`{"query": "test", "count": 10}`), + } + response := manager.DistributeWork(nil, workRequest) + Expect(response.Error).ToNot(BeEmpty()) + }) + }) +}) diff --git a/pkg/workers/worker_manager.go b/pkg/workers/worker_manager.go index 1665fc61..e7558b3a 100644 --- a/pkg/workers/worker_manager.go +++ b/pkg/workers/worker_manager.go @@ -84,8 +84,8 @@ func (whm *WorkHandlerManager) addWorkHandler(wType data_types.WorkerType, handl whm.handlers[wType] = &WorkHandlerInfo{Handler: handler} } -// getWorkHandler retrieves a registered work handler by name. -func (whm *WorkHandlerManager) getWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) { +// GetWorkHandler retrieves a registered work handler by name. +func (whm *WorkHandlerManager) GetWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) { whm.mu.RLock() defer whm.mu.RUnlock() info, exists := whm.handlers[wType] @@ -248,7 +248,7 @@ func (whm *WorkHandlerManager) sendWorkToWorker(node *node.OracleNode, worker da // ExecuteWork finds and executes the work handler associated with the given name. // It tracks the call count and execution duration for the handler. func (whm *WorkHandlerManager) ExecuteWork(workRequest data_types.WorkRequest) (response data_types.WorkResponse) { - handler, exists := whm.getWorkHandler(workRequest.WorkType) + handler, exists := whm.GetWorkHandler(workRequest.WorkType) if !exists { return data_types.WorkResponse{Error: ErrHandlerNotFound.Error()} }