diff --git a/Makefile b/Makefile index 186f6bcd..477b2c4e 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ stake: build client: build @./bin/masa-node-cli -# TODO Add -race and fix race conditions +# TODO: Add -race and fix race conditions test: contracts/node_modules @go test -coverprofile=coverage.txt -covermode=atomic -v -count=1 -shuffle=on ./... diff --git a/cmd/masa-node/config.go b/cmd/masa-node/config.go index a149b1ba..87f0f5e6 100644 --- a/cmd/masa-node/config.go +++ b/cmd/masa-node/config.go @@ -3,14 +3,23 @@ package main import ( "github.com/masa-finance/masa-oracle/node" "github.com/masa-finance/masa-oracle/pkg/config" + "github.com/masa-finance/masa-oracle/pkg/masacrypto" pubsub "github.com/masa-finance/masa-oracle/pkg/pubsub" "github.com/masa-finance/masa-oracle/pkg/workers" ) -func initOptions(cfg *config.AppConfig) ([]node.Option, *workers.WorkHandlerManager, *pubsub.PublicKeySubscriptionHandler) { +func initOptions(cfg *config.AppConfig, keyManager *masacrypto.KeyManager) ([]node.Option, *workers.WorkHandlerManager, *pubsub.PublicKeySubscriptionHandler) { // WorkerManager configuration - // XXX: this needs to be moved under config, but now it's here as there are import cycles given singletons - workerManagerOptions := []workers.WorkerOptionFunc{} + // TODO: this needs to be moved under config, but now it's here as there are import cycles given singletons + workerManagerOptions := []workers.WorkerOptionFunc{ + workers.WithLlmChatUrl(cfg.LLMChatUrl), + workers.WithMasaDir(cfg.MasaDir), + } + + cachePath := cfg.CachePath + if cachePath == "" { + cachePath = cfg.MasaDir + "/cache" + } masaNodeOptions := []node.Option{ node.EnableStaked, @@ -19,6 +28,10 @@ func initOptions(cfg *config.AppConfig) ([]node.Option, *workers.WorkHandlerMana node.WithVersion(cfg.Version), node.WithPort(cfg.PortNbr), node.WithBootNodes(cfg.Bootnodes...), + node.WithMasaDir(cfg.MasaDir), + node.WithCachePath(cachePath), + node.WithLLMCloudFlareURL(cfg.LLMCfUrl), + node.WithKeyManager(keyManager), } if cfg.TwitterScraper { @@ -50,8 +63,7 @@ func initOptions(cfg *config.AppConfig) ([]node.Option, *workers.WorkHandlerMana blockChainEventTracker := node.NewBlockChain() pubKeySub := &pubsub.PublicKeySubscriptionHandler{} - // TODO: Where the config is involved, move to the config the generation of - // Node options + // TODO: Where the config is involved, move to the config the generation of Node options masaNodeOptions = append(masaNodeOptions, []node.Option{ // Register the worker manager node.WithMasaProtocolHandler( @@ -68,7 +80,7 @@ func initOptions(cfg *config.AppConfig) ([]node.Option, *workers.WorkHandlerMana // and other peers can do work we only need to check this here // if this peer can or cannot scrape or write that is checked in other places masaNodeOptions = append(masaNodeOptions, - node.WithService(blockChainEventTracker.Start(config.GetInstance().MasaDir)), + node.WithService(blockChainEventTracker.Start(cfg.MasaDir)), ) } diff --git a/cmd/masa-node/main.go b/cmd/masa-node/main.go index 4ce683ca..6443fb9b 100644 --- a/cmd/masa-node/main.go +++ b/cmd/masa-node/main.go @@ -34,13 +34,16 @@ func main() { cfg.LogConfig() cfg.SetupLogging() - keyManager := masacrypto.KeyManagerInstance() + keyManager, err := masacrypto.NewKeyManager(cfg.PrivateKey, cfg.PrivateKeyFile) + if err != nil { + logrus.Fatal("[-] Failed to initialize keys:", err) + } // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) if cfg.Faucet { - err := handleFaucet(keyManager.EcdsaPrivKey) + err := handleFaucet(cfg.RpcUrl, keyManager.EcdsaPrivKey) if err != nil { logrus.Errorf("[-] %v", err) os.Exit(1) @@ -51,7 +54,7 @@ func main() { } if cfg.StakeAmount != "" { - err := handleStaking(keyManager.EcdsaPrivKey) + err := handleStaking(cfg.RpcUrl, keyManager.EcdsaPrivKey, cfg.StakeAmount) if err != nil { logrus.Warningf("%v", err) } else { @@ -61,7 +64,7 @@ func main() { } // Verify the staking event - isStaked, err := staking.VerifyStakingEvent(keyManager.EthAddress) + isStaked, err := staking.VerifyStakingEvent(cfg.RpcUrl, keyManager.EthAddress) if err != nil { logrus.Error(err) } @@ -70,7 +73,7 @@ func main() { logrus.Warn("No staking event found for this address") } - masaNodeOptions, workHandlerManager, pubKeySub := initOptions(cfg) + masaNodeOptions, workHandlerManager, pubKeySub := initOptions(cfg, keyManager) // Create a new OracleNode masaNode, err := node.NewOracleNode(ctx, masaNodeOptions...) @@ -95,7 +98,7 @@ func main() { } // Init cache resolver - db.InitResolverCache(masaNode, keyManager) + db.InitResolverCache(masaNode, keyManager, cfg.AllowedPeerId, cfg.AllowedPeerPublicKey, cfg.Validator) // Cancel the context when SIGINT is received go handleSignals(cancel, masaNode, cfg) diff --git a/cmd/masa-node/staking.go b/cmd/masa-node/staking.go index 19889b4c..36b07890 100644 --- a/cmd/masa-node/staking.go +++ b/cmd/masa-node/staking.go @@ -9,20 +9,19 @@ import ( "github.com/fatih/color" "github.com/sirupsen/logrus" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/masa-finance/masa-oracle/pkg/staking" ) -func handleStaking(privateKey *ecdsa.PrivateKey) error { +func handleStaking(rpcUrl string, privateKey *ecdsa.PrivateKey, stakeAmount string) error { // Staking logic // Convert the stake amount to the smallest unit, assuming 18 decimal places - amountBigInt, ok := new(big.Int).SetString(config.GetInstance().StakeAmount, 10) + amountBigInt, ok := new(big.Int).SetString(stakeAmount, 10) if !ok { logrus.Fatal("Invalid stake amount") } amountInSmallestUnit := new(big.Int).Mul(amountBigInt, big.NewInt(1e18)) - stakingClient, err := staking.NewClient(privateKey) + stakingClient, err := staking.NewClient(rpcUrl, privateKey) if err != nil { return err } @@ -86,8 +85,8 @@ func handleStaking(privateKey *ecdsa.PrivateKey) error { return nil } -func handleFaucet(privateKey *ecdsa.PrivateKey) error { - faucetClient, err := staking.NewClient(privateKey) +func handleFaucet(rpcUrl string, privateKey *ecdsa.PrivateKey) error { + faucetClient, err := staking.NewClient(rpcUrl, privateKey) if err != nil { logrus.Error("[-] Failed to create staking client:", err) return err diff --git a/docs/oracle-node/twitter-sentiment.md b/docs/oracle-node/twitter-sentiment.md index a8b3fcb4..cee2dd8c 100644 --- a/docs/oracle-node/twitter-sentiment.md +++ b/docs/oracle-node/twitter-sentiment.md @@ -64,7 +64,7 @@ const ( #### Masa cli or code integration -Tweets are fetched using the Twitter Scraper library, as seen in the [llmbridge](file:///Users/john/Projects/masa/masa-oracle/pkg/llmbridge/sentiment_twitter.go#1%2C9-1%2C9) package. This process does not require Twitter API keys, making it accessible and straightforward. +Tweets are fetched using the Twitter Scraper library, as seen in the [llmbridge](../pkg/llmbridge/sentiment_twitter.go#1%2C9-1%2C9) package. This process does not require Twitter API keys, making it accessible and straightforward. ```go func AnalyzeSentimentTweets(tweets []*twitterscraper.Tweet, model string) (string, string, error) { ... } diff --git a/node/options.go b/node/options.go index c8df794e..9ba05324 100644 --- a/node/options.go +++ b/node/options.go @@ -4,6 +4,7 @@ import ( "context" "github.com/masa-finance/masa-oracle/node/types" + "github.com/masa-finance/masa-oracle/pkg/masacrypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/protocol" @@ -30,6 +31,10 @@ type NodeOption struct { MasaProtocolHandlers map[string]network.StreamHandler Environment string Version string + MasaDir string + CachePath string + LLMCloudflareUrl string + KeyManager *masacrypto.KeyManager } type PubSubHandlers struct { @@ -150,3 +155,27 @@ func WithPort(port int) Option { o.PortNbr = port } } + +func WithMasaDir(directory string) Option { + return func(o *NodeOption) { + o.MasaDir = directory + } +} + +func WithCachePath(path string) Option { + return func(o *NodeOption) { + o.CachePath = path + } +} + +func WithLLMCloudFlareURL(url string) Option { + return func(o *NodeOption) { + o.LLMCloudflareUrl = url + } +} + +func WithKeyManager(km *masacrypto.KeyManager) Option { + return func(o *NodeOption) { + o.KeyManager = km + } +} diff --git a/node/oracle_node.go b/node/oracle_node.go index 762ff7c9..8ccb535a 100644 --- a/node/oracle_node.go +++ b/node/oracle_node.go @@ -27,7 +27,6 @@ import ( "github.com/masa-finance/masa-oracle/internal/versioning" "github.com/masa-finance/masa-oracle/pkg/chain" "github.com/masa-finance/masa-oracle/pkg/config" - "github.com/masa-finance/masa-oracle/pkg/masacrypto" myNetwork "github.com/masa-finance/masa-oracle/pkg/network" "github.com/masa-finance/masa-oracle/pkg/pubsub" ) @@ -48,6 +47,7 @@ type OracleNode struct { Blockchain *chain.Chain Options NodeOption Context context.Context + Config *config.AppConfig } // GetMultiAddrs returns the priority multiaddr for this node. @@ -102,7 +102,7 @@ func NewOracleNode(ctx context.Context, opts ...Option) (*OracleNode, error) { if o.RandomIdentity { libp2pOptions = append(libp2pOptions, libp2p.RandomIdentity) } else { - libp2pOptions = append(libp2pOptions, libp2p.Identity(masacrypto.KeyManagerInstance().Libp2pPrivKey)) + libp2pOptions = append(libp2pOptions, libp2p.Identity(o.KeyManager.Libp2pPrivKey)) } securityOptions := []libp2p.Option{ @@ -179,7 +179,7 @@ func (node *OracleNode) getNodeData() *pubsub.NodeData { if node.Options.RandomIdentity { publicEthAddress, _ = node.generateEthHexKeyForRandomIdentity() } else { - publicEthAddress = masacrypto.KeyManagerInstance().EthAddress + publicEthAddress = node.Options.KeyManager.EthAddress } nodeData := pubsub.NewNodeData(node.priorityAddrs, node.Host.ID(), publicEthAddress, pubsub.ActivityJoined) @@ -245,7 +245,7 @@ func (node *OracleNode) Start() (err error) { go p(node.Context, node) } - go myNetwork.Discover(node.Context, node.Host, node.DHT, node.Protocol) + go myNetwork.Discover(node.Context, node.Options.Bootnodes, node.Host, node.DHT, node.Protocol) nodeData := node.NodeTracker.GetNodeData(node.Host.ID().String()) if nodeData == nil { @@ -328,15 +328,14 @@ func (node *OracleNode) handleStream(stream network.Stream) { // IsWorker determines if the OracleNode is configured to act as an actor. // An actor node is one that has at least one of the following scrapers enabled: -// TwitterScraper, DiscordScraper, or WebScraper. +// TwitterScraper, DiscordScraper, TelegramScraper or WebScraper. // It returns true if any of these scrapers are enabled, otherwise false. func (node *OracleNode) IsWorker() bool { // need to get this by node data - cfg := config.GetInstance() - if cfg.TwitterScraper || cfg.DiscordScraper || cfg.TelegramScraper || cfg.WebScraper { - return true - } - return false + return node.Options.IsTwitterScraper || + node.Options.IsDiscordScraper || + node.Options.IsTelegramScraper || + node.Options.IsWebScraper } // IsPublisher returns true if this node is a publisher node. @@ -348,7 +347,7 @@ func (node *OracleNode) IsPublisher() bool { // Version returns the current version string of the oracle node software. func (node *OracleNode) Version() string { - return config.GetInstance().Version + return node.Options.Version } // LogActiveTopics logs the currently active topic names to the diff --git a/pkg/api/handlers_data.go b/pkg/api/handlers_data.go index fb8fb443..08f8d1ec 100644 --- a/pkg/api/handlers_data.go +++ b/pkg/api/handlers_data.go @@ -781,7 +781,7 @@ func (api *API) CfLlmChat() gin.HandlerFunc { } api.sendTrackingEvent(data_types.LLMChat, bodyBytes) - cfUrl := config.GetInstance().LLMCfUrl + cfUrl := api.Node.Options.LLMCloudflareUrl if cfUrl == "" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Errorf("missing env LLM_CF_URL")}) return diff --git a/pkg/api/handlers_node.go b/pkg/api/handlers_node.go index d50d926f..7c8d789b 100644 --- a/pkg/api/handlers_node.go +++ b/pkg/api/handlers_node.go @@ -11,7 +11,6 @@ import ( "github.com/masa-finance/masa-oracle/pkg/consensus" "github.com/masa-finance/masa-oracle/pkg/db" - "github.com/masa-finance/masa-oracle/pkg/masacrypto" "github.com/sirupsen/logrus" "github.com/gin-gonic/gin" @@ -181,7 +180,7 @@ func (api *API) PublishPublicKeyHandler() gin.HandlerFunc { return } - keyManager := masacrypto.KeyManagerInstance() + keyManager := api.Node.Options.KeyManager // Set the data to be signed as the signer's Peer ID data := []byte(api.Node.Host.ID().String()) diff --git a/pkg/config/app.go b/pkg/config/app.go index 4bdae773..ba04fb8a 100644 --- a/pkg/config/app.go +++ b/pkg/config/app.go @@ -91,6 +91,7 @@ func GetInstance() *AppConfig { instance = &AppConfig{} instance.setDefaultConfig() + // TODO Shouldn't the env vars override the file config, instead of the other way around? instance.setEnvVariableConfig() instance.setFileConfig(viper.GetString("FILE_PATH")) diff --git a/pkg/db/access_control.go b/pkg/db/access_control.go index 9840998d..47ea5044 100644 --- a/pkg/db/access_control.go +++ b/pkg/db/access_control.go @@ -8,11 +8,11 @@ package db import ( "encoding/hex" + libp2pCrypto "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/sirupsen/logrus" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/masa-finance/masa-oracle/pkg/consensus" ) @@ -34,14 +34,7 @@ func isAuthorized(nodeID string) bool { } // Verifier checks if the given host is allowed to access to the database and verifies the signature -func Verifier(h host.Host, data []byte, signature []byte) bool { - // Load configuration instance - cfg := config.GetInstance() - - // Get allowed peer ID and public key from the configuration - allowedPeerID := cfg.AllowedPeerId - allowedPeerPubKeyString := cfg.AllowedPeerPublicKey - +func Verifier(h host.Host, data []byte, signature []byte, allowedPeerID string, allowedPeerPubKeyString string, isValidator bool) bool { if allowedPeerID == "" || allowedPeerPubKeyString == "" { logrus.Warn("[-] Allowed peer ID or public key not found in configuration") return false @@ -81,7 +74,7 @@ func Verifier(h host.Host, data []byte, signature []byte) bool { return false } - if cfg.Validator { + if isValidator { logrus.WithFields(logrus.Fields{ "hostID": h.ID().String(), diff --git a/pkg/db/operations.go b/pkg/db/operations.go index 31ad8da7..91fd5e83 100644 --- a/pkg/db/operations.go +++ b/pkg/db/operations.go @@ -1,4 +1,4 @@ -// TODO Rename this to something else, this is NOT a database (it just stores data in the DHT and in a cache) +// TODO: Rename this to something else, this is NOT a database (it just stores data in the DHT and in a cache) package db import ( diff --git a/pkg/db/resolver_cache.go b/pkg/db/resolver_cache.go index 4473c8da..685fc242 100644 --- a/pkg/db/resolver_cache.go +++ b/pkg/db/resolver_cache.go @@ -43,13 +43,10 @@ type Record struct { // The purpose of this function is to initialize the resolver cache and perform any necessary setup or configuration. It associates the resolver cache with the provided Masa Oracle node and key manager. // // Note: The specific implementation details of the `InitResolverCache` function are not provided in the given code snippet. The function signature suggests that it initializes the resolver cache, but the actual initialization logic would be present in the function body. -func InitResolverCache(node *node.OracleNode, keyManager *masacrypto.KeyManager) { +func InitResolverCache(node *node.OracleNode, keyManager *masacrypto.KeyManager, allowedPeerID string, allowedPeerPubKeyString string, isValidator bool) { var err error - cachePath := config.GetInstance().CachePath - if cachePath == "" { - cachePath = config.GetInstance().MasaDir + "/cache" - } - cache, err = leveldb.NewDatastore(cachePath, nil) + + cache, err = leveldb.NewDatastore(node.Options.CachePath, nil) if err != nil { log.Fatal(err) } @@ -60,7 +57,7 @@ func InitResolverCache(node *node.OracleNode, keyManager *masacrypto.KeyManager) if err != nil { logrus.Errorf("[-] Error signing data: %v", err) } - _ = Verifier(node.Host, data, signature) + _ = Verifier(node.Host, data, signature, allowedPeerID, allowedPeerPubKeyString, isValidator) go monitorNodeData(context.Background(), node) diff --git a/pkg/llmbridge/client.go b/pkg/llmbridge/client.go index dd9b35a4..11000e80 100644 --- a/pkg/llmbridge/client.go +++ b/pkg/llmbridge/client.go @@ -11,7 +11,6 @@ import ( "net/http" "strings" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/sashabaranov/go-openai" ) @@ -24,15 +23,13 @@ type GPTClient struct { } // NewClaudeClient creates a new ClaudeClient instance with default configuration. -func NewClaudeClient() *ClaudeClient { - cnf := NewClaudeAPIConfig() - return &ClaudeClient{config: cnf} +func NewClaudeClient(config *ClaudeAPIConfig) *ClaudeClient { + return &ClaudeClient{config: config} } // NewGPTClient creates a new GPTClient instance with default configuration. -func NewGPTClient() *GPTClient { - cnf := NewGPTConfig() - return &GPTClient{config: cnf} +func NewGPTClient(config *GPTAPIConfig) *GPTClient { + return &GPTClient{config: config} } // SendRequest sends an HTTP request to the Claude API with the given payload. @@ -64,8 +61,7 @@ func (c *GPTClient) SendRequest(payload string, model string, prompt string) (st break } - cfg := config.GetInstance() - key := cfg.GPTApiKey + key := c.config.APIKey if key == "" { return "", errors.New("OPENAI_API_KEY is not set") } diff --git a/pkg/llmbridge/config.go b/pkg/llmbridge/config.go index 0b6e2b33..ea957e35 100644 --- a/pkg/llmbridge/config.go +++ b/pkg/llmbridge/config.go @@ -14,9 +14,7 @@ type GPTAPIConfig struct { // NewClaudeAPIConfig creates a new ClaudeAPIConfig instance with values loaded // from the application config. -func NewClaudeAPIConfig() *ClaudeAPIConfig { - appConfig := config.GetInstance() - +func NewClaudeAPIConfig(appConfig *config.AppConfig) *ClaudeAPIConfig { // need to add these to the config package return &ClaudeAPIConfig{ URL: appConfig.ClaudeApiURL, @@ -27,9 +25,7 @@ func NewClaudeAPIConfig() *ClaudeAPIConfig { // NewGPTConfig creates a new GPTConfig instance with values loaded // from the application config. -func NewGPTConfig() *GPTAPIConfig { - appConfig := config.GetInstance() - +func NewGPTConfig(appConfig *config.AppConfig) *GPTAPIConfig { // need to add these to the config package return &GPTAPIConfig{ APIKey: appConfig.GPTApiKey, diff --git a/pkg/llmbridge/sentiment.go b/pkg/llmbridge/sentiment.go index 799e50cc..a2a7a24d 100644 --- a/pkg/llmbridge/sentiment.go +++ b/pkg/llmbridge/sentiment.go @@ -21,9 +21,17 @@ import ( // It concatenates the tweets, creates a payload, sends a request to Claude, parses the response, // and returns the concatenated tweets content, a sentiment summary, and any error. func AnalyzeSentimentTweets(tweets []*twitterscraper.TweetResult, model string, prompt string) (string, string, error) { + appConfig := config.GetInstance() + // check if we are using claude or gpt, can add others easily if strings.Contains(model, "claude-") { - client := NewClaudeClient() // Adjusted to call without arguments + client := NewClaudeClient( + &ClaudeAPIConfig{ + URL: appConfig.ClaudeApiURL, + APIKey: appConfig.ClaudeApiKey, + Version: appConfig.ClaudeApiVersion, + }, + ) var validTweets []*twitterscraper.TweetResult for _, tweet := range tweets { @@ -54,7 +62,11 @@ func AnalyzeSentimentTweets(tweets []*twitterscraper.TweetResult, model string, return tweetsContent, sentimentSummary, nil } else if strings.Contains(model, "gpt-") { - client := NewGPTClient() + client := NewGPTClient( + &GPTAPIConfig{ + APIKey: appConfig.GPTApiKey, + }, + ) tweetsContent := ConcatenateTweets(tweets) sentimentSummary, err := client.SendRequest(tweetsContent, model, prompt) if err != nil { @@ -84,7 +96,7 @@ func AnalyzeSentimentTweets(tweets []*twitterscraper.TweetResult, model string, if err != nil { return "", "", err } - uri := config.GetInstance().LLMChatUrl + uri := appConfig.LLMChatUrl if uri == "" { return "", "", errors.New("ollama api url not set") } @@ -124,9 +136,17 @@ func ConcatenateTweets(tweets []*twitterscraper.TweetResult) string { // It concatenates the text, creates a payload, sends a request to Claude, parses the response, // and returns the concatenated content, a sentiment summary, and any error. func AnalyzeSentimentWeb(data string, model string, prompt string) (string, string, error) { + appConfig := config.GetInstance() + // check if we are using claude or gpt, can add others easily if strings.Contains(model, "claude-") { - client := NewClaudeClient() // Adjusted to call without arguments + client := NewClaudeClient( + &ClaudeAPIConfig{ + URL: appConfig.ClaudeApiURL, + APIKey: appConfig.ClaudeApiKey, + Version: appConfig.ClaudeApiVersion, + }, + ) payloadBytes, err := CreatePayload(data, model, prompt) if err != nil { logrus.Errorf("[-] Error creating payload: %v", err) @@ -146,7 +166,11 @@ func AnalyzeSentimentWeb(data string, model string, prompt string) (string, stri return data, sentimentSummary, nil } else if strings.Contains(model, "gpt-") { - client := NewGPTClient() + client := NewGPTClient( + &GPTAPIConfig{ + APIKey: appConfig.GPTApiKey, + }, + ) sentimentSummary, err := client.SendRequest(data, model, prompt) if err != nil { logrus.Errorf("[-] Error sending request to GPT: %v", err) @@ -166,7 +190,7 @@ func AnalyzeSentimentWeb(data string, model string, prompt string) (string, stri if err != nil { return "", "", err } - cfUrl := config.GetInstance().LLMCfUrl + cfUrl := appConfig.LLMCfUrl if cfUrl == "" { return "", "", errors.New("cloudflare workers url not set") } @@ -210,7 +234,7 @@ func AnalyzeSentimentWeb(data string, model string, prompt string) (string, stri if err != nil { return "", "", err } - uri := config.GetInstance().LLMChatUrl + uri := appConfig.LLMChatUrl if uri == "" { return "", "", errors.New("ollama api url not set") } @@ -241,6 +265,8 @@ func AnalyzeSentimentWeb(data string, model string, prompt string) (string, stri // It concatenates the messages, creates a payload, sends a request to the sentiment analysis service, parses the response, // and returns the concatenated messages content, a sentiment summary, and any error. func AnalyzeSentimentDiscord(messages []string, model string, prompt string) (string, string, error) { + appConfig := config.GetInstance() + // Concatenate messages with a newline character messagesContent := strings.Join(messages, "\n") @@ -248,7 +274,14 @@ func AnalyzeSentimentDiscord(messages []string, model string, prompt string) (st // Replace with the actual logic you have for sending requests to your sentiment analysis service // For example, if you're using the Claude API: if strings.Contains(model, "claude-") { - client := NewClaudeClient() // Adjusted to call without arguments + client := NewClaudeClient( + &ClaudeAPIConfig{ + URL: appConfig.ClaudeApiURL, + APIKey: appConfig.ClaudeApiKey, + Version: appConfig.ClaudeApiVersion, + }, + ) + payloadBytes, err := CreatePayload(messagesContent, model, prompt) if err != nil { logrus.Errorf("[-] Error creating payload: %v", err) @@ -289,7 +322,7 @@ func AnalyzeSentimentDiscord(messages []string, model string, prompt string) (st logrus.Errorf("[-] Error marshaling request JSON: %v", err) return "", "", err } - uri := config.GetInstance().LLMChatUrl + uri := appConfig.LLMChatUrl if uri == "" { errMsg := "ollama api url not set" logrus.Errorf("[-] %v", errMsg) @@ -321,6 +354,8 @@ func AnalyzeSentimentDiscord(messages []string, model string, prompt string) (st // AnalyzeSentimentTelegram analyzes the sentiment of the provided Telegram messages by sending them to the sentiment analysis API. func AnalyzeSentimentTelegram(messages []*tg.Message, model string, prompt string) (string, string, error) { + appConfig := config.GetInstance() + // Concatenate messages with a newline character var messageTexts []string for _, msg := range messages { @@ -332,7 +367,13 @@ func AnalyzeSentimentTelegram(messages []*tg.Message, model string, prompt strin // The rest of the code follows the same pattern as AnalyzeSentimentDiscord if strings.Contains(model, "claude-") { - client := NewClaudeClient() // Adjusted to call without arguments + client := NewClaudeClient( + &ClaudeAPIConfig{ + URL: appConfig.ClaudeApiURL, + APIKey: appConfig.ClaudeApiKey, + Version: appConfig.ClaudeApiVersion, + }, + ) payloadBytes, err := CreatePayload(messagesContent, model, prompt) if err != nil { logrus.Errorf("Error creating payload: %v", err) @@ -373,7 +414,7 @@ func AnalyzeSentimentTelegram(messages []*tg.Message, model string, prompt strin logrus.Errorf("[-] Error marshaling request JSON: %v", err) return "", "", err } - uri := config.GetInstance().LLMChatUrl + uri := appConfig.LLMChatUrl if uri == "" { err := errors.New("[-] ollama api url not set") logrus.Errorf("%v", err) diff --git a/pkg/masacrypto/key_manager.go b/pkg/masacrypto/key_manager.go index 426223d1..596922e6 100644 --- a/pkg/masacrypto/key_manager.go +++ b/pkg/masacrypto/key_manager.go @@ -3,13 +3,9 @@ package masacrypto import ( "crypto/ecdsa" "fmt" - "sync" ethCrypto "github.com/ethereum/go-ethereum/crypto" "github.com/libp2p/go-libp2p/core/crypto" - "github.com/sirupsen/logrus" - - "github.com/masa-finance/masa-oracle/pkg/config" ) // KeyManager is meant to simplify the management of cryptographic keys used in the application. @@ -32,19 +28,6 @@ import ( // to Ethereum address format. // - Ensures thread-safe initialization and access to the cryptographic keys through the // use of the sync.Once mechanism. -// -// Usage: -// To access the KeyManager and its functionalities, use the KeyManagerInstance() function -// which returns the singleton instance of KeyManager. This instance can then be used to -// perform various key management tasks, such as retrieving the application's cryptographic -// keys, converting key formats, and more. -// Example: -// keyManager := crypto.KeyManagerInstance() - -var ( - keyManagerInstance *KeyManager - once sync.Once -) // KeyManager holds all the cryptographic entities used in the application. type KeyManager struct { @@ -57,64 +40,58 @@ type KeyManager struct { EthAddress string // Ethereum format address } -// KeyManagerInstance returns the singleton instance of KeyManager, initializing it if necessary. -func KeyManagerInstance() *KeyManager { - once.Do(func() { - keyManagerInstance = &KeyManager{} - if err := keyManagerInstance.loadPrivateKey(); err != nil { - logrus.Fatal("[-] Failed to initialize keys:", err) - } - }) - return keyManagerInstance -} - -// loadPrivateKey loads the node's private key from the environment or a file. -// It first checks for a private key set via the PrivateKey config. If not found, -// it tries to load the key from the PrivateKeyFile. As a last resort, it -// generates a new key and saves it to the private key file. - +// NewKeyManager returns an initialized KeyManager. It first checks for a +// private key set via the PrivateKey config. If not found, it tries to +// load the key from the PrivateKeyFile. As a last resort, it generates +// a new key and saves it to the private key file. // The private key is loaded into both Libp2p and ECDSA formats for use by // different parts of the system. The public key and hex-encoded key representations // are also derived. -func (km *KeyManager) loadPrivateKey() (err error) { - var keyFile string - cfg := config.GetInstance() - if len(cfg.PrivateKey) > 0 { - km.Libp2pPrivKey, err = getPrivateKeyFromEnv(cfg.PrivateKey) +func NewKeyManager(privateKey string, privateKeyFile string) (*KeyManager, error) { + km := &KeyManager{} + + var err error + + if len(privateKey) > 0 { + km.Libp2pPrivKey, err = getPrivateKeyFromEnv(privateKey) if err != nil { - return err + return nil, err } } else { - keyFile = config.GetInstance().PrivateKeyFile // Check if the private key file exists - km.Libp2pPrivKey, err = getPrivateKeyFromFile(keyFile) + km.Libp2pPrivKey, err = getPrivateKeyFromFile(privateKeyFile) if err != nil { - km.Libp2pPrivKey, err = generateNewPrivateKey(keyFile) + km.Libp2pPrivKey, err = generateNewPrivateKey(privateKeyFile) if err != nil { - return err + return nil, err } } } + km.Libp2pPubKey = km.Libp2pPrivKey.GetPublic() + // After obtaining the libp2p privKey, convert it to an ECDSA private key km.EcdsaPrivKey, err = libp2pPrivateKeyToEcdsa(km.Libp2pPrivKey) if err != nil { - return err + return nil, err } - err = saveEcdesaPrivateKeyToFile(km.EcdsaPrivKey, fmt.Sprintf("%s.ecdsa", keyFile)) + err = saveEcdesaPrivateKeyToFile(km.EcdsaPrivKey, fmt.Sprintf("%s.ecdsa", privateKeyFile)) if err != nil { - return err + return nil, err } + km.HexPrivKey, err = getHexEncodedPrivateKey(km.Libp2pPrivKey) if err != nil { - return err + return nil, err } + km.EcdsaPubKey = &km.EcdsaPrivKey.PublicKey km.HexPubKey, err = getHexEncodedPublicKey(km.Libp2pPubKey) if err != nil { - return err + return nil, err } + km.EthAddress = ethCrypto.PubkeyToAddress(km.EcdsaPrivKey.PublicKey).Hex() - return nil + return km, nil } diff --git a/pkg/network/address.go b/pkg/network/address.go index a1674dbb..2de17ba4 100644 --- a/pkg/network/address.go +++ b/pkg/network/address.go @@ -64,7 +64,7 @@ func getPublicMultiAddress(addrs []multiaddr.Multiaddr) multiaddr.Multiaddr { } // GetPriorityAddress returns the best public or private IP address -// TODO rm? +// TODO: rm? func GetPriorityAddress(addrs []multiaddr.Multiaddr) multiaddr.Multiaddr { var bestPrivateAddr multiaddr.Multiaddr bestPublicAddr := getPublicMultiAddress(addrs) diff --git a/pkg/network/discover.go b/pkg/network/discover.go index 5c25c385..757628eb 100644 --- a/pkg/network/discover.go +++ b/pkg/network/discover.go @@ -10,8 +10,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/multiformats/go-multiaddr" - "github.com/masa-finance/masa-oracle/pkg/config" - dht "github.com/libp2p/go-libp2p-kad-dht" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -24,7 +22,7 @@ import ( // It initializes discovery via the DHT and advertises this node. // It runs discovery in a loop with a ticker, re-advertising and finding new peers. // For each discovered peer, it checks if already connected, and if not, dials them. -func Discover(ctx context.Context, host host.Host, dht *dht.IpfsDHT, protocol protocol.ID) { +func Discover(ctx context.Context, bootNodes []string, host host.Host, dht *dht.IpfsDHT, protocol protocol.ID) { var routingDiscovery *routing.RoutingDiscovery protocolString := string(protocol) logrus.Infof("[+] Discovering peers for protocol: %s", protocolString) @@ -97,16 +95,12 @@ func Discover(ctx context.Context, host host.Host, dht *dht.IpfsDHT, protocol pr ID: availPeer.ID, Addrs: availPeer.Addrs, } - hostAddrInfo := peer.AddrInfo{ - ID: host.ID(), - Addrs: host.Addrs(), - } - if availPeerAddrInfo.ID.String() == hostAddrInfo.ID.String() { + if availPeerAddrInfo.ID == host.ID() { logrus.Debugf("Skipping connect to self: %s", availPeerAddrInfo.ID.String()) continue } if len(availPeerAddrInfo.Addrs) == 0 { - for _, bn := range config.GetInstance().Bootnodes { + for _, bn := range bootNodes { bootNode := strings.Split(bn, "/")[len(strings.Split(bn, "/"))-1] if availPeerAddrInfo.ID.String() != bootNode { logrus.Warningf("Skipping connect to non bootnode peer with no multiaddress: %s", availPeerAddrInfo.ID.String()) @@ -117,7 +111,7 @@ func Discover(ctx context.Context, host host.Host, dht *dht.IpfsDHT, protocol pr logrus.Infof("[+] Available Peer: %s", availPeer.String()) if host.Network().Connectedness(availPeer.ID) != network.Connected { - if isConnectedToBootnode(host, config.GetInstance().Bootnodes) { + if isConnectedToBootnode(host, bootNodes) { _, err := host.Network().DialPeer(ctx, availPeer.ID) if err != nil { logrus.Warningf("[-] Failed to connect to peer %s, will retry...", availPeer.ID.String()) @@ -126,10 +120,10 @@ func Discover(ctx context.Context, host host.Host, dht *dht.IpfsDHT, protocol pr logrus.Infof("[+] Connected to peer %s", availPeer.ID.String()) } } else { - for _, bn := range config.GetInstance().Bootnodes { + for _, bn := range bootNodes { if len(bn) > 0 { logrus.Info("[-] Not connected to any bootnode. Attempting to reconnect...") - reconnectToBootnodes(ctx, host, config.GetInstance().Bootnodes) + reconnectToBootnodes(ctx, host, bootNodes) } } } diff --git a/pkg/scrapers/twitter/common.go b/pkg/scrapers/twitter/common.go index 008c7aed..8a3ee557 100644 --- a/pkg/scrapers/twitter/common.go +++ b/pkg/scrapers/twitter/common.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/joho/godotenv" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/sirupsen/logrus" ) @@ -49,9 +48,8 @@ func parseAccounts(accountPairs []string) []*TwitterAccount { }) } -func getAuthenticatedScraper() (*Scraper, *TwitterAccount, error) { +func getAuthenticatedScraper(baseDir string) (*Scraper, *TwitterAccount, error) { once.Do(initializeAccountManager) - baseDir := config.GetInstance().MasaDir account := accountManager.GetNextAccount() if account == nil { diff --git a/pkg/scrapers/twitter/followers.go b/pkg/scrapers/twitter/followers.go index 25617904..72020b8c 100644 --- a/pkg/scrapers/twitter/followers.go +++ b/pkg/scrapers/twitter/followers.go @@ -7,8 +7,8 @@ import ( "github.com/sirupsen/logrus" ) -func ScrapeFollowersForProfile(username string, count int) ([]twitterscraper.Legacy, error) { - scraper, account, err := getAuthenticatedScraper() +func ScrapeFollowersForProfile(baseDir string, username string, count int) ([]twitterscraper.Legacy, error) { + scraper, account, err := getAuthenticatedScraper(baseDir) if err != nil { return nil, err } diff --git a/pkg/scrapers/twitter/profile.go b/pkg/scrapers/twitter/profile.go index cfe77096..547ae987 100644 --- a/pkg/scrapers/twitter/profile.go +++ b/pkg/scrapers/twitter/profile.go @@ -4,8 +4,8 @@ import ( twitterscraper "github.com/masa-finance/masa-twitter-scraper" ) -func ScrapeTweetsProfile(username string) (twitterscraper.Profile, error) { - scraper, account, err := getAuthenticatedScraper() +func ScrapeTweetsProfile(baseDir string, username string) (twitterscraper.Profile, error) { + scraper, account, err := getAuthenticatedScraper(baseDir) if err != nil { return twitterscraper.Profile{}, err } diff --git a/pkg/scrapers/twitter/tweets.go b/pkg/scrapers/twitter/tweets.go index b32b2c4d..e58506ac 100644 --- a/pkg/scrapers/twitter/tweets.go +++ b/pkg/scrapers/twitter/tweets.go @@ -11,8 +11,8 @@ type TweetResult struct { Error error } -func ScrapeTweetsByQuery(query string, count int) ([]*TweetResult, error) { - scraper, account, err := getAuthenticatedScraper() +func ScrapeTweetsByQuery(baseDir string, query string, count int) ([]*TweetResult, error) { + scraper, account, err := getAuthenticatedScraper(baseDir) if err != nil { return nil, err } diff --git a/pkg/staking/contracts.go b/pkg/staking/contracts.go index d58eeb9e..5092da73 100644 --- a/pkg/staking/contracts.go +++ b/pkg/staking/contracts.go @@ -6,8 +6,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" - - "github.com/masa-finance/masa-oracle/pkg/config" ) var MasaTokenAddress common.Address @@ -18,10 +16,10 @@ type Client struct { PrivateKey *ecdsa.PrivateKey } -// NewClient initializes a new Client instance with the provided private key. +// NewClient initializes a new ethClient.Client instance with the provided private key. // It loads the contract addresses, initializes an Ethereum client, and returns // a Client instance. -func NewClient(privateKey *ecdsa.PrivateKey) (*Client, error) { +func NewClient(rpcUrl string, privateKey *ecdsa.PrivateKey) (*Client, error) { addresses, err := LoadContractAddresses() if err != nil { return nil, fmt.Errorf("[-] Failed to load contract addresses: %v", err) @@ -30,7 +28,7 @@ func NewClient(privateKey *ecdsa.PrivateKey) (*Client, error) { MasaTokenAddress = common.HexToAddress(addresses.Sepolia.MasaToken) ProtocolStakingContractAddress = common.HexToAddress(addresses.Sepolia.ProtocolStaking) - client, err := ethclient.Dial(config.GetInstance().RpcUrl) + client, err := ethclient.Dial(rpcUrl) if err != nil { return nil, err } diff --git a/pkg/staking/verify.go b/pkg/staking/verify.go index 73019fad..c4ea6168 100644 --- a/pkg/staking/verify.go +++ b/pkg/staking/verify.go @@ -9,21 +9,14 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" - - "github.com/masa-finance/masa-oracle/pkg/config" ) // VerifyStakingEvent checks if the given user address has staked tokens by // calling the stakes() view function on the ProtocolStaking contract. // It connects to an Ethereum node, encodes the stakes call, calls the contract, // unpacks the result, and returns true if the stakes amount is > 0. -func VerifyStakingEvent(userAddress string) (bool, error) { - rpcURL := config.GetInstance().RpcUrl - if rpcURL == "" { - return false, fmt.Errorf("%s is not set", config.RpcUrl) - } - - client, err := ethclient.Dial(rpcURL) +func VerifyStakingEvent(rpcUrl string, userAddress string) (bool, error) { + client, err := ethclient.Dial(rpcUrl) if err != nil { return false, fmt.Errorf("[-] Failed to connect to the Ethereum client: %v", err) } diff --git a/pkg/tests/scrapers/twitter_scraper_test.go b/pkg/tests/scrapers/twitter_scraper_test.go index 1031ee6d..42544978 100644 --- a/pkg/tests/scrapers/twitter_scraper_test.go +++ b/pkg/tests/scrapers/twitter_scraper_test.go @@ -14,7 +14,6 @@ import ( "runtime" "github.com/joho/godotenv" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/masa-finance/masa-oracle/pkg/scrapers/twitter" twitterscraper "github.com/masa-finance/masa-twitter-scraper" . "github.com/onsi/ginkgo/v2" @@ -27,6 +26,7 @@ var _ = Describe("Twitter Auth Function", func() { twitterUsername string twitterPassword string twoFACode string + masaDir string ) loadEnv := func() { @@ -45,8 +45,7 @@ var _ = Describe("Twitter Auth Function", func() { BeforeEach(func() { loadEnv() - tempDir := GinkgoT().TempDir() - config.GetInstance().MasaDir = tempDir + masaDir = GinkgoT().TempDir() twitterUsername = os.Getenv("TWITTER_USERNAME") twitterPassword = os.Getenv("TWITTER_PASSWORD") @@ -54,20 +53,18 @@ var _ = Describe("Twitter Auth Function", func() { Expect(twitterUsername).NotTo(BeEmpty(), "TWITTER_USERNAME environment variable is not set") Expect(twitterPassword).NotTo(BeEmpty(), "TWITTER_PASSWORD environment variable is not set") - - config.GetInstance().TwitterUsername = twitterUsername - config.GetInstance().TwitterPassword = twitterPassword - config.GetInstance().Twitter2FaCode = twoFACode + Expect(twoFACode).NotTo(BeEmpty(), "TWITTER_PASSWORD environment variable is not set") }) authenticate := func() *twitterscraper.Scraper { + // TODO Actually authenticate return nil //return twitter.Auth() } PIt("authenticates and logs in successfully", func() { // Ensure cookie file doesn't exist before authentication - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).NotTo(BeAnExistingFile()) // Authenticate @@ -81,7 +78,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(scraper.IsLoggedIn()).To(BeTrue()) // Attempt a simple operation to verify the session is valid - profile, err := twitter.ScrapeTweetsProfile("twitter") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "twitter") Expect(err).To(BeNil()) Expect(profile.Username).To(Equal("twitter")) @@ -94,7 +91,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(firstScraper).NotTo(BeNil()) // Verify cookie file is created - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).To(BeAnExistingFile()) // Clear the scraper to force cookie reuse @@ -108,7 +105,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(secondScraper.IsLoggedIn()).To(BeTrue()) // Attempt a simple operation to verify the session is valid - profile, err := twitter.ScrapeTweetsProfile("twitter") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "twitter") Expect(err).To(BeNil()) Expect(profile.Username).To(Equal("twitter")) @@ -121,7 +118,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(firstScraper).NotTo(BeNil()) // Verify cookie file is created - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).To(BeAnExistingFile()) // Clear the scraper to force cookie reuse @@ -135,12 +132,12 @@ var _ = Describe("Twitter Auth Function", func() { Expect(secondScraper.IsLoggedIn()).To(BeTrue()) // Attempt to scrape profile - profile, err := twitter.ScrapeTweetsProfile("god") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "god") Expect(err).To(BeNil()) logrus.Infof("Profile of 'god': %+v", profile) // Scrape recent #Bitcoin tweets - tweets, err := twitter.ScrapeTweetsByQuery("#Bitcoin", 3) + tweets, err := twitter.ScrapeTweetsByQuery(masaDir, "#Bitcoin", 3) Expect(err).To(BeNil()) Expect(tweets).To(HaveLen(3)) @@ -151,6 +148,6 @@ var _ = Describe("Twitter Auth Function", func() { }) AfterEach(func() { - os.RemoveAll(config.GetInstance().MasaDir) + os.RemoveAll(masaDir) }) }) diff --git a/pkg/tests/twitter/twitter_scraper_test.go b/pkg/tests/twitter/twitter_scraper_test.go index 9ede2a0f..b41587cb 100644 --- a/pkg/tests/twitter/twitter_scraper_test.go +++ b/pkg/tests/twitter/twitter_scraper_test.go @@ -6,7 +6,6 @@ import ( "runtime" "github.com/joho/godotenv" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/masa-finance/masa-oracle/pkg/scrapers/twitter" twitterscraper "github.com/masa-finance/masa-twitter-scraper" . "github.com/onsi/ginkgo/v2" @@ -19,6 +18,7 @@ var _ = Describe("Twitter Auth Function", func() { twitterUsername string twitterPassword string twoFACode string + masaDir string ) loadEnv := func() { @@ -37,8 +37,7 @@ var _ = Describe("Twitter Auth Function", func() { BeforeEach(func() { loadEnv() - tempDir := GinkgoT().TempDir() - config.GetInstance().MasaDir = tempDir + masaDir = GinkgoT().TempDir() twitterUsername = os.Getenv("TWITTER_USERNAME") twitterPassword = os.Getenv("TWITTER_PASSWORD") @@ -46,20 +45,18 @@ var _ = Describe("Twitter Auth Function", func() { Expect(twitterUsername).NotTo(BeEmpty(), "TWITTER_USERNAME environment variable is not set") Expect(twitterPassword).NotTo(BeEmpty(), "TWITTER_PASSWORD environment variable is not set") - - config.GetInstance().TwitterUsername = twitterUsername - config.GetInstance().TwitterPassword = twitterPassword - config.GetInstance().Twitter2FaCode = twoFACode + Expect(twoFACode).NotTo(BeEmpty(), "TWITTER_PASSWORD environment variable is not set") }) authenticate := func() *twitterscraper.Scraper { + // TODO Actually authenticate return nil //return twitter.Auth() } PIt("authenticates and logs in successfully", func() { // Ensure cookie file doesn't exist before authentication - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).NotTo(BeAnExistingFile()) // Authenticate @@ -73,7 +70,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(scraper.IsLoggedIn()).To(BeTrue()) // Attempt a simple operation to verify the session is valid - profile, err := twitter.ScrapeTweetsProfile("twitter") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "twitter") Expect(err).To(BeNil()) Expect(profile.Username).To(Equal("twitter")) @@ -86,7 +83,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(firstScraper).NotTo(BeNil()) // Verify cookie file is created - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).To(BeAnExistingFile()) // Clear the scraper to force cookie reuse @@ -100,7 +97,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(secondScraper.IsLoggedIn()).To(BeTrue()) // Attempt a simple operation to verify the session is valid - profile, err := twitter.ScrapeTweetsProfile("twitter") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "twitter") Expect(err).To(BeNil()) Expect(profile.Username).To(Equal("twitter")) @@ -113,7 +110,7 @@ var _ = Describe("Twitter Auth Function", func() { Expect(firstScraper).NotTo(BeNil()) // Verify cookie file is created - cookieFile := filepath.Join(config.GetInstance().MasaDir, "twitter_cookies.json") + cookieFile := filepath.Join(masaDir, "twitter_cookies.json") Expect(cookieFile).To(BeAnExistingFile()) // Clear the scraper to force cookie reuse @@ -127,12 +124,12 @@ var _ = Describe("Twitter Auth Function", func() { Expect(secondScraper.IsLoggedIn()).To(BeTrue()) // Attempt to scrape profile - profile, err := twitter.ScrapeTweetsProfile("god") + profile, err := twitter.ScrapeTweetsProfile(masaDir, "god") Expect(err).To(BeNil()) logrus.Infof("Profile of 'god': %+v", profile) // Scrape recent #Bitcoin tweets - tweets, err := twitter.ScrapeTweetsByQuery("#Bitcoin", 3) + tweets, err := twitter.ScrapeTweetsByQuery(masaDir, "#Bitcoin", 3) Expect(err).To(BeNil()) Expect(tweets).To(HaveLen(3)) @@ -143,6 +140,6 @@ var _ = Describe("Twitter Auth Function", func() { }) AfterEach(func() { - os.RemoveAll(config.GetInstance().MasaDir) + os.RemoveAll(masaDir) }) }) diff --git a/pkg/workers/handlers/llm.go b/pkg/workers/handlers/llm.go index ed9687ae..e527a046 100644 --- a/pkg/workers/handlers/llm.go +++ b/pkg/workers/handlers/llm.go @@ -6,9 +6,8 @@ import ( "github.com/sirupsen/logrus" - "github.com/masa-finance/masa-oracle/pkg/config" "github.com/masa-finance/masa-oracle/pkg/network" - "github.com/masa-finance/masa-oracle/pkg/workers/types" + data_types "github.com/masa-finance/masa-oracle/pkg/workers/types" ) // TODO: LLMChatBody isn't used anywhere in the codebase. Remove after testing @@ -22,13 +21,20 @@ type LLMChatBody struct { Stream bool `json:"stream"` } -type LLMChatHandler struct{} +type LLMChatHandler struct { + llmChatUrl string +} + +func NewLLMChatHandler(llmChatUrl string) *LLMChatHandler { + return &LLMChatHandler{ + llmChatUrl: llmChatUrl, + } +} // HandleWork implements the WorkHandler interface for LLMChatHandler. func (h *LLMChatHandler) HandleWork(data []byte) data_types.WorkResponse { logrus.Infof("[+] LLM Chat %s", data) - uri := config.GetInstance().LLMChatUrl - if uri == "" { + if h.llmChatUrl == "" { return data_types.WorkResponse{Error: "missing env variable LLM_CHAT_URL"} } @@ -41,7 +47,7 @@ func (h *LLMChatHandler) HandleWork(data []byte) data_types.WorkResponse { if err != nil { return data_types.WorkResponse{Error: fmt.Sprintf("unable to marshal LLM chat data: %v", err)} } - resp, err := network.Post(uri, jsnBytes, nil) + resp, err := network.Post(h.llmChatUrl, jsnBytes, nil) if err != nil { return data_types.WorkResponse{Error: fmt.Sprintf("unable to post LLM chat data: %v", err)} } diff --git a/pkg/workers/handlers/twitter.go b/pkg/workers/handlers/twitter.go index bd4cc00e..5709a23e 100644 --- a/pkg/workers/handlers/twitter.go +++ b/pkg/workers/handlers/twitter.go @@ -9,9 +9,9 @@ import ( data_types "github.com/masa-finance/masa-oracle/pkg/workers/types" ) -type TwitterQueryHandler struct{} -type TwitterFollowersHandler struct{} -type TwitterProfileHandler struct{} +type TwitterQueryHandler struct{ MasaDir string } +type TwitterFollowersHandler struct{ MasaDir string } +type TwitterProfileHandler struct{ MasaDir string } func (h *TwitterQueryHandler) HandleWork(data []byte) data_types.WorkResponse { logrus.Infof("[+] TwitterQueryHandler input: %s", data) @@ -25,7 +25,7 @@ func (h *TwitterQueryHandler) HandleWork(data []byte) data_types.WorkResponse { logrus.Infof("[+] Scraping tweets for query: %s, count: %d", query, count) - resp, err := twitter.ScrapeTweetsByQuery(query, count) + resp, err := twitter.ScrapeTweetsByQuery(h.MasaDir, query, count) if err != nil { logrus.Errorf("[+] TwitterQueryHandler error scraping tweets: %v", err) return data_types.WorkResponse{Error: err.Error()} @@ -48,7 +48,7 @@ func (h *TwitterFollowersHandler) HandleWork(data []byte) data_types.WorkRespons } username := dataMap["username"].(string) count := int(dataMap["count"].(float64)) - resp, err := twitter.ScrapeFollowersForProfile(username, count) + resp, err := twitter.ScrapeFollowersForProfile(h.MasaDir, username, count) if err != nil { return data_types.WorkResponse{Error: fmt.Sprintf("unable to get twitter followers: %v", err)} } @@ -64,7 +64,7 @@ func (h *TwitterProfileHandler) HandleWork(data []byte) data_types.WorkResponse return data_types.WorkResponse{Error: fmt.Sprintf("unable to parse twitter profile data: %v", err)} } username := dataMap["username"].(string) - resp, err := twitter.ScrapeTweetsProfile(username) + resp, err := twitter.ScrapeTweetsProfile(h.MasaDir, username) if err != nil { return data_types.WorkResponse{Error: fmt.Sprintf("unable to get twitter profile: %v", err)} } diff --git a/pkg/workers/options.go b/pkg/workers/options.go index 0f261748..4387edf6 100644 --- a/pkg/workers/options.go +++ b/pkg/workers/options.go @@ -5,6 +5,8 @@ type WorkerOption struct { isWebScraperWorker bool isLLMServerWorker bool isDiscordScraperWorker bool + llmChatUrl string + masaDir string } type WorkerOptionFunc func(*WorkerOption) @@ -25,6 +27,18 @@ var EnableDiscordScraperWorker = func(o *WorkerOption) { o.isDiscordScraperWorker = true } +func WithLlmChatUrl(url string) WorkerOptionFunc { + return func(o *WorkerOption) { + o.llmChatUrl = url + } +} + +func WithMasaDir(dir string) WorkerOptionFunc { + return func(o *WorkerOption) { + o.masaDir = dir + } +} + func (a *WorkerOption) Apply(opts ...WorkerOptionFunc) { for _, opt := range opts { opt(a) diff --git a/pkg/workers/worker_manager.go b/pkg/workers/worker_manager.go index abe8ee4f..3f5ea65f 100644 --- a/pkg/workers/worker_manager.go +++ b/pkg/workers/worker_manager.go @@ -33,9 +33,9 @@ func NewWorkHandlerManager(opts ...WorkerOptionFunc) *WorkHandlerManager { } if options.isTwitterWorker { - whm.addWorkHandler(data_types.Twitter, &handlers.TwitterQueryHandler{}) - whm.addWorkHandler(data_types.TwitterFollowers, &handlers.TwitterFollowersHandler{}) - whm.addWorkHandler(data_types.TwitterProfile, &handlers.TwitterProfileHandler{}) + whm.addWorkHandler(data_types.Twitter, &handlers.TwitterQueryHandler{MasaDir: options.masaDir}) + whm.addWorkHandler(data_types.TwitterFollowers, &handlers.TwitterFollowersHandler{MasaDir: options.masaDir}) + whm.addWorkHandler(data_types.TwitterProfile, &handlers.TwitterProfileHandler{MasaDir: options.masaDir}) } if options.isWebScraperWorker { @@ -43,7 +43,7 @@ func NewWorkHandlerManager(opts ...WorkerOptionFunc) *WorkHandlerManager { } if options.isLLMServerWorker { - whm.addWorkHandler(data_types.LLMChat, &handlers.LLMChatHandler{}) + whm.addWorkHandler(data_types.LLMChat, handlers.NewLLMChatHandler(options.llmChatUrl)) } if options.isDiscordScraperWorker {