Skip to content

Commit

Permalink
Create MockHost for testing and update method visibility
Browse files Browse the repository at this point in the history
Introduce MockHost and MockNetwork structs for unit tests to simulate network interactions. Adjust method visibility in WorkHandlerManager for better accessibility and fix logging in AppConfig. Add new unit tests for worker selection and handlers.
  • Loading branch information
restevens402 committed Oct 6, 2024
1 parent 609e2ec commit 8692adb
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cmd/masa-node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package main
import (
"github.com/masa-finance/masa-oracle/node"
"github.com/masa-finance/masa-oracle/pkg/config"
pubsub "github.com/masa-finance/masa-oracle/pkg/pubsub"
"github.com/masa-finance/masa-oracle/pkg/pubsub"
"github.com/masa-finance/masa-oracle/pkg/workers"
)

Expand Down
2 changes: 1 addition & 1 deletion pkg/config/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *AppConfig) setFileConfig(path string) {
func (c *AppConfig) setEnvVariableConfig() {
err := godotenv.Load()
if err != nil {
logrus.Error("[-] Error loading .env file")
logrus.Errorf("[-] Error loading .env file %s", err)
}
viper.AutomaticEnv()
}
Expand Down
203 changes: 203 additions & 0 deletions pkg/tests/mock_host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package tests

import (
"context"
"fmt"

"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/multiformats/go-multiaddr"
)

type MockHost struct {
id peer.ID
}

func (m *MockHost) Peerstore() peerstore.Peerstore {
fmt.Printf("Peerstore called\n")
return nil
}

func (m *MockHost) Addrs() []multiaddr.Multiaddr {
fmt.Printf("Addrs called\n")
addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001")
return []multiaddr.Multiaddr{addr1}
}

func (m *MockHost) Network() network.Network {
fmt.Printf("Network called\n")
return &MockNetwork{}
}

func (m *MockHost) Mux() protocol.Switch {
fmt.Printf("Mux called\n")
return nil
}

func (m *MockHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
fmt.Printf("Connect called with peer info: %v\n", pi)
if ctx == nil {
fmt.Printf("nil context\n")
}
return nil
}

func (m *MockHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
fmt.Printf("SetStreamHandler called with protocol ID: %s\n", pid)
if handler == nil {
fmt.Printf("nil handler\n")
}
}

func (m *MockHost) SetStreamHandlerMatch(id protocol.ID, f func(protocol.ID) bool, handler network.StreamHandler) {
fmt.Printf("SetStreamHandlerMatch called with protocol ID: %s\n", id)
if handler == nil {
fmt.Printf("nil handler\n")
}
}

func (m *MockHost) RemoveStreamHandler(pid protocol.ID) {
fmt.Printf("RemoveStreamHandler called with protocol ID: %s\n", pid)
}

func (m *MockHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) {
fmt.Printf("NewStream called with peer: %s, protocol IDs: %v\n", p, pids)
if ctx == nil {
fmt.Printf("nil context\n")
}
return nil, nil
}

func (m *MockHost) Close() error {
fmt.Printf("Close called\n")
return nil
}

func (m *MockHost) ConnManager() connmgr.ConnManager {
fmt.Printf("ConnManager called\n")
return nil
}

func (m *MockHost) EventBus() event.Bus {
fmt.Printf("EventBus called\n")
return nil
}

func (m *MockHost) ID() peer.ID {
fmt.Printf("ID called\n")
return m.id
}

type MockNetwork struct{}

func (m *MockNetwork) Close() error {
fmt.Printf("Close called\n")
return nil
}

func (m *MockNetwork) CanDial(p peer.ID, addr multiaddr.Multiaddr) bool {
fmt.Printf("CanDial called with peer: %s, addr: %s\n", p, addr)
return true
}

func (m *MockNetwork) DialPeer(ctx context.Context, id peer.ID) (network.Conn, error) {
fmt.Printf("DialPeer called with peer: %s\n", id)
if ctx == nil {
fmt.Printf("nil context\n")
}
return nil, nil
}

func (m *MockNetwork) SetStreamHandler(handler network.StreamHandler) {
fmt.Printf("SetStreamHandler called\n")
if handler == nil {
fmt.Printf("nil handler\n")
}
}

func (m *MockNetwork) NewStream(ctx context.Context, id peer.ID) (network.Stream, error) {
fmt.Printf("NewStream called with peer: %s\n", id)
if ctx == nil {
fmt.Printf("nil context\n")
}
return nil, nil
}

func (m *MockNetwork) Listen(m2 ...multiaddr.Multiaddr) error {
fmt.Printf("Listen called with addresses: %v\n", m2)
return nil
}

func (m *MockNetwork) ResourceManager() network.ResourceManager {
fmt.Printf("ResourceManager called\n")
return nil
}

func (m *MockNetwork) Peerstore() peerstore.Peerstore {
fmt.Printf("Peerstore called\n")
return nil
}

func (m *MockNetwork) LocalPeer() peer.ID {
fmt.Printf("LocalPeer called\n")
return "mockLocalPeerID"
}

func (m *MockNetwork) ListenAddresses() []multiaddr.Multiaddr {
fmt.Printf("ListenAddresses called\n")
addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001")
return []multiaddr.Multiaddr{addr1}
}

func (m *MockNetwork) InterfaceListenAddresses() ([]multiaddr.Multiaddr, error) {
fmt.Printf("InterfaceListenAddresses called\n")
addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001")
return []multiaddr.Multiaddr{addr1}, nil
}

func (m *MockNetwork) Connectedness(p peer.ID) network.Connectedness {
fmt.Printf("Connectedness called with peer: %s\n", p)
return network.NotConnected
}

func (m *MockNetwork) Peers() []peer.ID {
fmt.Printf("Peers called\n")
return []peer.ID{}
}

func (m *MockNetwork) Conns() []network.Conn {
fmt.Printf("Conns called\n")
return []network.Conn{}
}

func (m *MockNetwork) ConnsToPeer(p peer.ID) []network.Conn {
fmt.Printf("ConnsToPeer called with peer: %s\n", p)
return []network.Conn{}
}

func (m *MockNetwork) Notify(notifier network.Notifiee) {
fmt.Printf("Notify called\n")
if notifier == nil {
fmt.Printf("nil notifier\n")
}
}

func (m *MockNetwork) StopNotify(notifier network.Notifiee) {
fmt.Printf("StopNotify called\n")
if notifier == nil {
fmt.Printf("nil notifier\n")
}
}

func (m *MockNetwork) ClosePeer(p peer.ID) error {
fmt.Printf("ClosePeer called with peer: %s\n", p)
return nil
}

func (m *MockNetwork) RemovePeer(p peer.ID) {
fmt.Printf("RemovePeer called with peer: %s\n", p)
}
133 changes: 133 additions & 0 deletions pkg/tests/worker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package tests

import (
"context"
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/masa-finance/masa-oracle/node"
"github.com/masa-finance/masa-oracle/pkg/pubsub"
"github.com/masa-finance/masa-oracle/pkg/workers"
datatypes "github.com/masa-finance/masa-oracle/pkg/workers/types"
)

func TestWorkers(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Workers Suite")
}

var _ = Describe("Worker Selection", func() {
var (
oracleNode1 *node.OracleNode
oracleNode2 *node.OracleNode
category pubsub.WorkerCategory
)

BeforeEach(func() {
ctx := context.Background()

// Start the first node with a random identity
n1, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper)
Expect(err).ToNot(HaveOccurred())
err = n1.Start()
Expect(err).ToNot(HaveOccurred())

// Get the address of the first node to use as a bootstrap node
addrs, err := n1.GetP2PMultiAddrs()
Expect(err).ToNot(HaveOccurred())

var bootNodes []string
for _, addr := range addrs {
bootNodes = append(bootNodes, addr.String())
}

// Start the second node with a random identity and bootstrap to the first node
n2, err := node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTelegramScraper, node.WithBootNodes(bootNodes...))
Expect(err).ToNot(HaveOccurred())
err = n2.Start()
Expect(err).ToNot(HaveOccurred())

n2.Host = &MockHost{id: "mockHostID1"}
oracleNode1 = n1
oracleNode2 = n2
category = pubsub.CategoryTwitter
})

AfterEach(func() {
//oracleNode1.Stop()
//oracleNode2.Stop()
})

Describe("GetEligibleWorkers", func() {
It("should return empty remote workers and a local worker", func() {
// Wait for the nodes to see each other
Eventually(func() bool {
datas := oracleNode1.NodeTracker.GetAllNodeData()
return len(datas) == 2
}, "30s").Should(BeTrue())

Eventually(func() bool {
datas := oracleNode2.NodeTracker.GetAllNodeData()
return len(datas) == 2
}, "30s").Should(BeTrue())

remoteWorkers, localWorker := workers.GetEligibleWorkers(oracleNode1, category)

Expect(remoteWorkers).To(BeEmpty())
Expect(localWorker).ToNot(BeNil())
})
})
})

var _ = Describe("WorkHandlerManager", func() {
var (
oracleNode *node.OracleNode
manager *workers.WorkHandlerManager
)

BeforeEach(func() {
manager = workers.NewWorkHandlerManager(workers.EnableTwitterWorker)
ctx := context.Background()
var err error
// Start the first node with a random identity
oracleNode, err = node.NewOracleNode(ctx, node.EnableStaked, node.EnableRandomIdentity, node.IsTwitterScraper)
Expect(err).ToNot(HaveOccurred())
err = oracleNode.Start()
Expect(err).ToNot(HaveOccurred())
})

Describe("Add and Get WorkHandler", func() {
It("should add and retrieve a work handler", func() {
handler, exists := manager.GetWorkHandler(datatypes.Twitter)
Expect(exists).To(BeTrue())
Expect(handler).ToNot(BeNil())
})

It("should return false for non-existent work handler", func() {
_, exists := manager.GetWorkHandler(datatypes.WorkerType("NonExistent"))
Expect(exists).To(BeFalse())
})
})

Describe("DistributeWork", func() {
It("should distribute work to eligible workers", func() {
workRequest := datatypes.WorkRequest{
WorkType: datatypes.Twitter,
Data: []byte(`{"query": "test", "count": 10}`),
}
response := manager.DistributeWork(oracleNode, workRequest)
Expect(response.Error).To(BeEmpty())
})

It("should handle errors in work distribution", func() {
workRequest := datatypes.WorkRequest{
WorkType: datatypes.WorkerType("InvalidType"),
Data: []byte(`{"query": "test", "count": 10}`),
}
response := manager.DistributeWork(nil, workRequest)
Expect(response.Error).ToNot(BeEmpty())
})
})
})
6 changes: 3 additions & 3 deletions pkg/workers/worker_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ func (whm *WorkHandlerManager) addWorkHandler(wType data_types.WorkerType, handl
whm.handlers[wType] = &WorkHandlerInfo{Handler: handler}
}

// getWorkHandler retrieves a registered work handler by name.
func (whm *WorkHandlerManager) getWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) {
// GetWorkHandler retrieves a registered work handler by name.
func (whm *WorkHandlerManager) GetWorkHandler(wType data_types.WorkerType) (WorkHandler, bool) {
whm.mu.RLock()
defer whm.mu.RUnlock()
info, exists := whm.handlers[wType]
Expand Down Expand Up @@ -248,7 +248,7 @@ func (whm *WorkHandlerManager) sendWorkToWorker(node *node.OracleNode, worker da
// ExecuteWork finds and executes the work handler associated with the given name.
// It tracks the call count and execution duration for the handler.
func (whm *WorkHandlerManager) ExecuteWork(workRequest data_types.WorkRequest) (response data_types.WorkResponse) {
handler, exists := whm.getWorkHandler(workRequest.WorkType)
handler, exists := whm.GetWorkHandler(workRequest.WorkType)
if !exists {
return data_types.WorkResponse{Error: ErrHandlerNotFound.Error()}
}
Expand Down

0 comments on commit 8692adb

Please sign in to comment.