diff --git a/.gitignore b/.gitignore index 2e2a175..2535279 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ /golbd /golbd.exe + +/vendor +.idea \ No newline at end of file diff --git a/lbcluster/lbcluster.go b/lbcluster/lbcluster.go index 05998e3..717ba56 100644 --- a/lbcluster/lbcluster.go +++ b/lbcluster/lbcluster.go @@ -4,38 +4,34 @@ import ( "encoding/json" "fmt" "io/ioutil" + "lb-experts/golbd/metric" "math/rand" "net" "net/http" - "strings" - - "gitlab.cern.ch/lb-experts/golbd/lbhost" - "sort" + "strings" + "sync" "time" + + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" ) //WorstValue worst possible load const WorstValue int = 99999 -//TIMEOUT snmp timeout -const TIMEOUT int = 10 - -//OID snmp object to get -const OID string = ".1.3.6.1.4.1.96.255.1" - //LBCluster struct of an lbcluster alias type LBCluster struct { - Cluster_name string - Loadbalancing_username string - Loadbalancing_password string + ClusterConfig model.ClusterConfig Host_metric_table map[string]Node Parameters Params Time_of_last_evaluation time.Time Current_best_ips []net.IP Previous_best_ips_dns []net.IP Current_index int - Slog *Log + Slog logger.Logger + MetricLogic metric.Logic } //Params of the alias @@ -52,9 +48,9 @@ type Params struct { // Shuffle pseudo-randomizes the order of elements. // n is the number of elements. Shuffle panics if n < 0. // swap swaps the elements with indexes i and j. -func Shuffle(n int, swap func(i, j int)) { +func Shuffle(n int, swap func(i, j int)) error { if n < 0 { - panic("invalid argument to Shuffle") + return fmt.Errorf("invalid argument to Shuffle") } // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle @@ -72,12 +68,14 @@ func Shuffle(n int, swap func(i, j int)) { j := int(rand.Int31n(int32(i + 1))) swap(i, j) } + return nil } //Node Struct to keep the ips and load of a node for an alias type Node struct { - Load int - IPs []net.IP + Load int + IPs []net.IP + HostName string } //NodeList struct for the list @@ -92,24 +90,19 @@ func (lbc *LBCluster) Time_to_refresh() bool { return lbc.Time_of_last_evaluation.Add(time.Duration(lbc.Parameters.Polling_interval) * time.Second).Before(time.Now()) } -//Get_list_hosts Get the hosts for an alias -func (lbc *LBCluster) Get_list_hosts(current_list map[string]lbhost.LBHost) { - lbc.Write_to_log("DEBUG", "Getting the list of hosts for the alias") +//GetHostList Get the hosts for an alias +func (lbc *LBCluster) GetHostList(hostMap map[string]lbhost.Host) { + + lbc.Slog.Debug("Getting the list of hosts for the alias") for host := range lbc.Host_metric_table { - myHost, ok := current_list[host] + myHost, ok := hostMap[host] if ok { - myHost.Cluster_name = myHost.Cluster_name + "," + lbc.Cluster_name + clusterConfig := myHost.GetClusterConfig() + clusterConfig.Cluster_name = clusterConfig.Cluster_name + "," + lbc.ClusterConfig.Cluster_name } else { - myHost = lbhost.LBHost{ - Cluster_name: lbc.Cluster_name, - Host_name: host, - Loadbalancing_username: lbc.Loadbalancing_username, - Loadbalancing_password: lbc.Loadbalancing_password, - LogFile: lbc.Slog.TofilePath, - Debugflag: lbc.Slog.Debugflag, - } + myHost = lbhost.NewLBHost(lbc.ClusterConfig, lbc.Slog) } - current_list[host] = myHost + hostMap[host] = myHost } } @@ -133,7 +126,7 @@ func (lbc *LBCluster) concatenateIps(myIps []net.IP) string { } //Find_best_hosts Looks for the best hosts for a cluster -func (lbc *LBCluster) FindBestHosts(hosts_to_check map[string]lbhost.LBHost) bool { +func (lbc *LBCluster) FindBestHosts(hosts_to_check map[string]lbhost.Host) (bool, error) { lbc.EvaluateHosts(hosts_to_check) allMetrics := make(map[string]bool) @@ -143,24 +136,28 @@ func (lbc *LBCluster) FindBestHosts(hosts_to_check map[string]lbhost.LBHost) boo _, ok := allMetrics[lbc.Parameters.Metric] if !ok { - lbc.Write_to_log("ERROR", "wrong parameter(metric) in definition of cluster "+lbc.Parameters.Metric) - return false + lbc.Slog.Error("wrong parameter(metric) in definition of cluster " + lbc.Parameters.Metric) + return false, nil } lbc.Time_of_last_evaluation = time.Now() - if !lbc.ApplyMetric(hosts_to_check) { - return false + shouldApplyMetric, err := lbc.ApplyMetric(hosts_to_check) + if err != nil { + return false, err + } + if !shouldApplyMetric { + return false, nil } nodes := lbc.concatenateIps(lbc.Current_best_ips) if len(lbc.Current_best_ips) == 0 { nodes = "NONE" } - lbc.Write_to_log("INFO", "best hosts are: "+nodes) - return true + lbc.Slog.Info("best hosts are: " + nodes) + return true, nil } // ApplyMetric This is the core of the lbcluster: based on the metrics, select the best hosts -func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool { - lbc.Write_to_log("INFO", "Got metric = "+lbc.Parameters.Metric) +func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.Host) (bool, error) { + lbc.Slog.Info("Got metric = " + lbc.Parameters.Metric) pl := make(NodeList, len(lbc.Host_metric_table)) i := 0 for _, v := range lbc.Host_metric_table { @@ -168,9 +165,12 @@ func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool i++ } //Let's shuffle the hosts before sorting them, in case some hosts have the same value - Shuffle(len(pl), func(i, j int) { pl[i], pl[j] = pl[j], pl[i] }) + err := Shuffle(len(pl), func(i, j int) { pl[i], pl[j] = pl[j], pl[i] }) + if err != nil { + return false, err + } sort.Sort(pl) - lbc.Write_to_log("DEBUG", fmt.Sprintf("%v", pl)) + lbc.Slog.Debug(fmt.Sprintf("%v", pl)) var sorted_host_list []Node var useful_host_list []Node for _, v := range pl { @@ -179,7 +179,7 @@ func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool } sorted_host_list = append(sorted_host_list, v) } - lbc.Write_to_log("DEBUG", fmt.Sprintf("%v", useful_host_list)) + lbc.Slog.Debug(fmt.Sprintf("%v", useful_host_list)) useful_hosts := len(useful_host_list) listLength := len(pl) max := lbc.Parameters.Best_hosts @@ -187,17 +187,17 @@ func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool max = listLength } if max > listLength { - lbc.Write_to_log("WARNING", fmt.Sprintf("impossible to return %v hosts from the list of %v hosts (%v). Check the configuration of cluster. Returning %v hosts.", + lbc.Slog.Warning(fmt.Sprintf("impossible to return %v hosts from the list of %v hosts (%v). Check the configuration of cluster. Returning %v hosts.", max, listLength, lbc.concatenateNodes(sorted_host_list), listLength)) max = listLength } lbc.Current_best_ips = []net.IP{} if listLength == 0 { - lbc.Write_to_log("ERROR", "cluster has no hosts defined ! Check the configuration.") + lbc.Slog.Error("cluster has no hosts defined ! Check the configuration.") } else if useful_hosts == 0 { if lbc.Parameters.Metric == "minimum" { - lbc.Write_to_log("WARNING", fmt.Sprintf("no usable hosts found for cluster! Returning random %v hosts.", max)) + lbc.Slog.Warning(fmt.Sprintf("no usable hosts found for cluster! Returning random %v hosts.", max)) //Get hosts with all IPs even when not OK for SNMP lbc.ReEvaluateHostsForMinimum(hosts_to_check) i := 0 @@ -206,21 +206,24 @@ func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool i++ } //Let's shuffle the hosts - Shuffle(len(pl), func(i, j int) { pl[i], pl[j] = pl[j], pl[i] }) + err := Shuffle(len(pl), func(i, j int) { pl[i], pl[j] = pl[j], pl[i] }) + if err != nil { + return false, err + } for i := 0; i < max; i++ { lbc.Current_best_ips = append(lbc.Current_best_ips, pl[i].IPs...) } - lbc.Write_to_log("WARNING", fmt.Sprintf("We have put random hosts behind the alias: %v", lbc.Current_best_ips)) + lbc.Slog.Warning(fmt.Sprintf("We have put random hosts behind the alias: %v", lbc.Current_best_ips)) } else if (lbc.Parameters.Metric == "minino") || (lbc.Parameters.Metric == "cmsweb") { - lbc.Write_to_log("WARNING", "no usable hosts found for cluster! Returning no hosts.") + lbc.Slog.Warning("no usable hosts found for cluster! Returning no hosts.") } else if lbc.Parameters.Metric == "cmsfrontier" { - lbc.Write_to_log("WARNING", "no usable hosts found for cluster! Skipping the DNS update") - return false + lbc.Slog.Warning("no usable hosts found for cluster! Skipping the DNS update") + return false, nil } } else { if useful_hosts < max { - lbc.Write_to_log("WARNING", fmt.Sprintf("only %v useable hosts found in cluster", useful_hosts)) + lbc.Slog.Warning(fmt.Sprintf("only %v useable hosts found in cluster", useful_hosts)) max = useful_hosts } for i := 0; i < max; i++ { @@ -228,7 +231,7 @@ func (lbc *LBCluster) ApplyMetric(hosts_to_check map[string]lbhost.LBHost) bool } } - return true + return true, nil } //NewTimeoutClient checks the timeout @@ -288,29 +291,48 @@ func (lbc *LBCluster) checkRogerState(host string) string { } //EvaluateHosts gets the load from the all the nodes -func (lbc *LBCluster) EvaluateHosts(hostsToCheck map[string]lbhost.LBHost) { - - for currenthost := range lbc.Host_metric_table { - host := hostsToCheck[currenthost] - ips, err := host.Get_working_IPs() - if err != nil { - ips, err = host.Get_Ips() - } - lbc.Host_metric_table[currenthost] = Node{host.Get_load_for_alias(lbc.Cluster_name), ips} - lbc.Write_to_log("DEBUG", fmt.Sprintf("node: %s It has a load of %d", currenthost, lbc.Host_metric_table[currenthost].Load)) +func (lbc *LBCluster) EvaluateHosts(hostsToCheck map[string]lbhost.Host) { + var nodeChan = make(chan Node) + defer close(nodeChan) + var wg sync.WaitGroup + newHostMetricMap := make(map[string]Node) + for k, v := range lbc.Host_metric_table { + newHostMetricMap[k] = v } + for currentHost := range newHostMetricMap { + wg.Add(1) + go func(selectedHost string) { + host := hostsToCheck[selectedHost] + ips, err := host.GetWorkingIPs() + if err != nil { + ips, err = host.GetIps() + if err != nil { + lbc.Slog.Error(fmt.Sprintf("error while fetching IPs. error: %v", err)) + } + } + nodeChan <- Node{host.GetLoadForAlias(lbc.ClusterConfig.Cluster_name), ips, selectedHost} + }(currentHost) + } + go func() { + for nodeData := range nodeChan { + lbc.Host_metric_table[nodeData.HostName] = Node{Load: nodeData.Load, IPs: nodeData.IPs} + lbc.Slog.Debug(fmt.Sprintf("node: %s It has a load of %d", nodeData.HostName, lbc.Host_metric_table[nodeData.HostName].Load)) + wg.Done() + } + }() + wg.Wait() } //ReEvaluateHostsForMinimum gets the load from the all the nodes for Minimum metric policy -func (lbc *LBCluster) ReEvaluateHostsForMinimum(hostsToCheck map[string]lbhost.LBHost) { +func (lbc *LBCluster) ReEvaluateHostsForMinimum(hostsToCheck map[string]lbhost.Host) { for currenthost := range lbc.Host_metric_table { host := hostsToCheck[currenthost] - ips, err := host.Get_all_IPs() + ips, err := host.GetAllIPs() if err != nil { - ips, err = host.Get_Ips() + ips, err = host.GetIps() } - lbc.Host_metric_table[currenthost] = Node{host.Get_load_for_alias(lbc.Cluster_name), ips} - lbc.Write_to_log("DEBUG", fmt.Sprintf("node: %s It has a load of %d", currenthost, lbc.Host_metric_table[currenthost].Load)) + lbc.Host_metric_table[currenthost] = Node{host.GetLoadForAlias(lbc.ClusterConfig.Cluster_name), ips, host.GetName()} + lbc.Slog.Debug(fmt.Sprintf("node: %s It has a load of %d", currenthost, lbc.Host_metric_table[currenthost].Load)) } } diff --git a/lbcluster/lbcluster_dns.go b/lbcluster/lbcluster_dns.go index 390e89b..4f170cb 100644 --- a/lbcluster/lbcluster_dns.go +++ b/lbcluster/lbcluster_dns.go @@ -2,6 +2,7 @@ package lbcluster import ( "fmt" + "lb-experts/golbd/metric" "net" "time" @@ -14,26 +15,26 @@ func (lbc *LBCluster) RefreshDNS(dnsManager, keyPrefix, internalKey, externalKey e := lbc.GetStateDNS(dnsManager) if e != nil { - lbc.Write_to_log("WARNING", fmt.Sprintf("Get_state_dns Error: %v", e.Error())) + lbc.Slog.Warning(fmt.Sprintf("Get_state_dns Error: %v", e.Error())) } pbiDNS := lbc.concatenateIps(lbc.Previous_best_ips_dns) cbi := lbc.concatenateIps(lbc.Current_best_ips) if pbiDNS == cbi { - lbc.Write_to_log("INFO", fmt.Sprintf("DNS not update keyName %v cbh == pbhDns == %v", keyPrefix, cbi)) + lbc.Slog.Info(fmt.Sprintf("DNS not update keyName %v cbh == pbhDns == %v", keyPrefix, cbi)) return } - lbc.Write_to_log("INFO", fmt.Sprintf("Updating the DNS with %v (previous state was %v)", cbi, pbiDNS)) + lbc.Slog.Info(fmt.Sprintf("Updating the DNS with %v (previous state was %v)", cbi, pbiDNS)) e = lbc.updateDNS(keyPrefix+"internal.", internalKey, dnsManager) if e != nil { - lbc.Write_to_log("WARNING", fmt.Sprintf("Internal Update_dns Error: %v", e.Error())) + lbc.Slog.Warning(fmt.Sprintf("Internal Update_dns Error: %v", e.Error())) } if lbc.externallyVisible() { e = lbc.updateDNS(keyPrefix+"external.", externalKey, dnsManager) if e != nil { - lbc.Write_to_log("WARNING", fmt.Sprintf("External Update_dns Error: %v", e.Error())) + lbc.Slog.Warning(fmt.Sprintf("External Update_dns Error: %v", e.Error())) } } } @@ -51,41 +52,59 @@ func (lbc *LBCluster) updateDNS(keyName, tsigKey, dnsManager string) error { } //best_hosts_len := len(lbc.Current_best_hosts) m := new(dns.Msg) - m.SetUpdate(lbc.Cluster_name + ".") + m.SetUpdate(lbc.ClusterConfig.Cluster_name + ".") m.Id = 1234 - rrRemoveA, _ := dns.NewRR(lbc.Cluster_name + ". " + ttl + " IN A 127.0.0.1") - rrRemoveAAAA, _ := dns.NewRR(lbc.Cluster_name + ". " + ttl + " IN AAAA ::1") + rrRemoveA, _ := dns.NewRR(lbc.ClusterConfig.Cluster_name + ". " + ttl + " IN A 127.0.0.1") + rrRemoveAAAA, _ := dns.NewRR(lbc.ClusterConfig.Cluster_name + ". " + ttl + " IN AAAA ::1") m.RemoveRRset([]dns.RR{rrRemoveA}) m.RemoveRRset([]dns.RR{rrRemoveAAAA}) - + retryModule := NewRetryModule(5*time.Second, lbc.Slog) + err := retryModule.SetMaxCount(10) + if err != nil { + return err + } for _, ip := range lbc.Current_best_ips { var rrInsert dns.RR if ip.To4() != nil { - rrInsert, _ = dns.NewRR(lbc.Cluster_name + ". " + ttl + " IN A " + ip.String()) + rrInsert, _ = dns.NewRR(lbc.ClusterConfig.Cluster_name + ". " + ttl + " IN A " + ip.String()) } else if ip.To16() != nil { - rrInsert, _ = dns.NewRR(lbc.Cluster_name + ". " + ttl + " IN AAAA " + ip.String()) + rrInsert, _ = dns.NewRR(lbc.ClusterConfig.Cluster_name + ". " + ttl + " IN AAAA " + ip.String()) } m.Insert([]dns.RR{rrInsert}) } - lbc.Write_to_log("INFO", fmt.Sprintf("WE WOULD UPDATE THE DNS WITH THE IPS %v", m)) + lbc.Slog.Info(fmt.Sprintf("WE WOULD UPDATE THE DNS WITH THE IPS %v", m)) c := new(dns.Client) m.SetTsig(keyName, dns.HmacMD5, 300, time.Now().Unix()) c.TsigSecret = map[string]string{keyName: tsigKey} - _, _, err := c.Exchange(m, dnsManager) + updateStartTime := time.Now() + err = retryModule.Execute(func() error { + _, _, err := c.Exchange(m, dnsManager+":53") + return err + }) + updateEndTime := time.Now() + if lbc.MetricLogic != nil { + err := lbc.MetricLogic.WriteRecord(metric.Property{ + RoundTripStartTime: updateStartTime, + RoundTripEndTime: updateEndTime, + }) + if err != nil { + return err + } + } if err != nil { - lbc.Write_to_log("ERROR", fmt.Sprintf("DNS update failed with (%v)", err)) + lbc.Slog.Error(fmt.Sprintf("DNS update failed with (%v)", err)) return err } - lbc.Write_to_log("INFO", fmt.Sprintf("DNS update with keyName %v", keyName)) + lbc.Slog.Info(fmt.Sprintf("DNS update with keyName %v", keyName)) return nil } func (lbc *LBCluster) getIpsFromDNS(m *dns.Msg, dnsManager string, dnsType uint16, ips *[]net.IP) error { - m.SetQuestion(lbc.Cluster_name+".", dnsType) + m.SetQuestion(lbc.ClusterConfig.Cluster_name+".", dnsType) in, err := dns.Exchange(m, dnsManager) if err != nil { - lbc.Write_to_log("ERROR", fmt.Sprintf("Error getting the ipv4 state of dns: %v", err)) + lbc.Slog.Error(fmt.Sprintf("Error getting the ipv4 state of dns: %v", err)) return err } for _, a := range in.Answer { @@ -104,7 +123,7 @@ func (lbc *LBCluster) GetStateDNS(dnsManager string) error { m := new(dns.Msg) var ips []net.IP m.SetEdns0(4096, false) - lbc.Write_to_log("DEBUG", "Getting the ips from the DNS") + lbc.Slog.Debug("Getting the ips from the DNS") err := lbc.getIpsFromDNS(m, dnsManager, dns.TypeA, &ips) if err != nil { @@ -115,7 +134,7 @@ func (lbc *LBCluster) GetStateDNS(dnsManager string) error { return err } - lbc.Write_to_log("INFO", fmt.Sprintf("Let's keep the list of ips : %v", ips)) + lbc.Slog.Info(fmt.Sprintf("Let's keep the list of ips : %v", ips)) lbc.Previous_best_ips_dns = ips return nil diff --git a/lbcluster/lbcluster_log.go b/lbcluster/lbcluster_log.go deleted file mode 100644 index 6ed72c8..0000000 --- a/lbcluster/lbcluster_log.go +++ /dev/null @@ -1,128 +0,0 @@ -package lbcluster - -import ( - "fmt" - "log/syslog" - "os" - "strings" - "sync" - "time" -) - -//Log struct for the log -type Log struct { - SyslogWriter *syslog.Writer - Stdout bool - Debugflag bool - TofilePath string - logMu sync.Mutex -} - -//Logger struct for the Logger interface -type Logger interface { - Info(s string) error - Warning(s string) error - Debug(s string) error - Error(s string) error -} - -//Write_to_log put something in the log file -func (lbc *LBCluster) Write_to_log(level string, msg string) error { - - myMessage := "cluster: " + lbc.Cluster_name + " " + msg - - if level == "INFO" { - lbc.Slog.Info(myMessage) - } else if level == "DEBUG" { - lbc.Slog.Debug(myMessage) - } else if level == "WARNING" { - lbc.Slog.Warning(myMessage) - } else if level == "ERROR" { - lbc.Slog.Error(myMessage) - } else { - lbc.Slog.Error("LEVEL " + level + " NOT UNDERSTOOD, ASSUMING ERROR " + myMessage) - } - - return nil -} - -//Info write as Info -func (l *Log) Info(s string) error { - var err error - if l.SyslogWriter != nil { - err = l.SyslogWriter.Info(s) - } - if l.Stdout || (l.TofilePath != "") { - err = l.writefilestd("INFO: " + s) - } - return err - -} - -//Warning write as Warning -func (l *Log) Warning(s string) error { - var err error - if l.SyslogWriter != nil { - err = l.SyslogWriter.Warning(s) - } - if l.Stdout || (l.TofilePath != "") { - err = l.writefilestd("WARNING: " + s) - } - return err - -} - -//Debug write as Debug -func (l *Log) Debug(s string) error { - var err error - if l.Debugflag { - if l.SyslogWriter != nil { - err = l.SyslogWriter.Debug(s) - } - if l.Stdout || (l.TofilePath != "") { - err = l.writefilestd("DEBUG: " + s) - } - } - return err - -} - -//Error write as Error -func (l *Log) Error(s string) error { - var err error - if l.SyslogWriter != nil { - err = l.SyslogWriter.Err(s) - } - if l.Stdout || (l.TofilePath != "") { - err = l.writefilestd("ERROR: " + s) - } - return err - -} - -func (l *Log) writefilestd(s string) error { - var err error - tag := "lbd" - nl := "" - if !strings.HasSuffix(s, "\n") { - nl = "\n" - } - timestamp := time.Now().Format(time.StampMilli) - msg := fmt.Sprintf("%s %s[%d]: %s%s", - timestamp, - tag, os.Getpid(), s, nl) - l.logMu.Lock() - defer l.logMu.Unlock() - if l.Stdout { - _, err = fmt.Printf(msg) - } - if l.TofilePath != "" { - f, err := os.OpenFile(l.TofilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0640) - if err != nil { - return err - } - defer f.Close() - _, err = fmt.Fprintf(f, msg) - } - return err -} diff --git a/lbcluster/retry_module.go b/lbcluster/retry_module.go new file mode 100644 index 0000000..1e22bb4 --- /dev/null +++ b/lbcluster/retry_module.go @@ -0,0 +1,131 @@ +package lbcluster + +import ( + "fmt" + "lb-experts/golbd/logger" + "time" +) + +type Retry struct { + signal chan int + done chan bool + tick chan bool + currentCount int + maxCount int + retryStarted bool + maxDuration time.Duration + retryDuration time.Duration + prevRetryDuration time.Duration + logger logger.Logger +} + +const defaultMaxDuration = 5 * time.Minute + +func NewRetryModule(retryStartDuration time.Duration, logger logger.Logger) *Retry { + retry := &Retry{ + maxCount: -1, + currentCount: 0, + maxDuration: defaultMaxDuration, + retryDuration: retryStartDuration, + prevRetryDuration: retryStartDuration, + logger: logger, + } + return retry +} + +func (r *Retry) SetMaxDuration(maxDuration time.Duration) error { + if r.retryStarted { + return fmt.Errorf("retry routine has already started") + } + if maxDuration <= 0 { + return fmt.Errorf("duration has to be greater than 0") + } + r.maxDuration = maxDuration + return nil +} + +func (r *Retry) SetMaxCount(maxCount int) error { + if r.retryStarted { + return fmt.Errorf("retry routine has already started") + } + if maxCount <= 0 { + return fmt.Errorf("max count has to be greater than 0") + } + r.maxCount = maxCount + return nil +} + +func (r *Retry) start() { + r.retryStarted = true + signal := make(chan int) + done := make(chan bool) + end := time.Tick(r.maxDuration) + r.signal = signal + r.done = done + go func() { + for { + select { + case <-end: + r.done <- true + return + default: + if r.currentCount == r.maxCount { + r.done <- true + return + } + } + } + }() + r.run() +} + +func (r *Retry) run() { + start := make(chan bool) + go func() { + ticker := time.NewTicker(1 * time.Minute) + defer close(r.done) + defer close(r.signal) + defer close(start) + defer ticker.Stop() + + for { + select { + case <-r.done: + return + case <-ticker.C: + r.nextTick(ticker) + case <-start: + r.nextTick(ticker) + } + } + }() + start <- true +} + +func (r *Retry) nextTick(ticker *time.Ticker) { + r.signal <- r.currentCount + 1 + r.currentCount += 1 + ticker.Reset(r.retryDuration) + r.computeNextRetryTime() +} + +func (r *Retry) Execute(executor func() error) error { + var err error + r.start() + for retryCount := range r.signal { + err = executor() + if err != nil { + r.logger.Debug(fmt.Sprintf("retry count: %v", retryCount)) + } else { + r.done <- true + } + } + return err +} + +// using fibonacci algorithm to compute the next run time +func (r *Retry) computeNextRetryTime() { + nextDuration := r.retryDuration + r.prevRetryDuration + r.prevRetryDuration = r.retryDuration + r.retryDuration = nextDuration +} diff --git a/lbconfig/config.go b/lbconfig/config.go index 942ffb2..1ce91ee 100644 --- a/lbconfig/config.go +++ b/lbconfig/config.go @@ -5,129 +5,246 @@ import ( "encoding/json" "fmt" "io" + "lb-experts/golbd/metric" "net" "os" "strconv" "strings" "sync" + "time" - "gitlab.cern.ch/lb-experts/golbd/lbcluster" "gopkg.in/yaml.v3" + + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" +) + +const ( + DefaultLoadBalancerConfig = "loadbalancing" + DefaultMetricsDirectoryPath = "" ) +type Config interface { + GetMasterHost() string + GetHeartBeatFileName() string + GetHeartBeatDirPath() string + GetMetricDirectoryPath() string + GetDNSManager() string + GetTSIGKeyPrefix() string + GetTSIGInternalKey() string + GetTSIGExternalKey() string + LockHeartBeatMutex() + UnlockHeartBeatMutex() + WatchFileChange(controlChan <-chan bool, waitGroup sync.WaitGroup) <-chan ConfigFileChangeSignal + Load() ([]lbcluster.LBCluster, error) + LoadClusters() ([]lbcluster.LBCluster, error) + + // testing only + SetMasterHost(masterHostName string) + SetHeartBeatFileName(heartBeatFileName string) + SetHeartBeatDirPath(heartBeatDirPath string) + SetDNSManager(dnsManager string) + SetTSIGKeyPrefix(tsigKeyPrefix string) + SetTSIGInternalKey(tsigInternalKey string) + SetTSIGExternalKey(tsigExternalKey string) + SetClusters(clusters map[string][]string) + SetSNMPPassword(password string) + SetParameters(params map[string]lbcluster.Params) +} + // Config this is the configuration of the lbd -type Config struct { - Master string - HeartbeatFile string - HeartbeatPath string - HeartbeatMu sync.Mutex - TsigKeyPrefix string - TsigInternalKey string - TsigExternalKey string - SnmpPassword string - DNSManager string - ConfigFile string - Clusters map[string][]string - Parameters map[string]lbcluster.Params -} - -func LoadConfig(configFile string, lg *lbcluster.Log) (*Config, []lbcluster.LBCluster, error) { - var configFunc func(configFile string, lg *lbcluster.Log) (*Config, []lbcluster.LBCluster, error) - - if strings.HasSuffix(configFile, ".yaml") { - configFunc = loadConfigYaml - } else { - configFunc = loadConfigOriginal - } +type LBConfig struct { + Master string + HeartbeatFile string + HeartbeatPath string + HeartbeatMu sync.Mutex + TsigKeyPrefix string + TsigInternalKey string + TsigExternalKey string + SnmpPassword string + DNSManager string + configFilePath string + lbLog logger.Logger + Clusters map[string][]string + Parameters map[string]lbcluster.Params + metricDirectoryPath string +} - return configFunc(configFile, lg) +type ConfigFileChangeSignal struct { + readSignal bool + readError error } -// readLines reads a whole file into memory and returns a slice of lines. -func readLines(path string) (lines []string, err error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() +func (fs ConfigFileChangeSignal) IsErrorPresent() bool { + return fs.readError != nil +} - sc := bufio.NewScanner(f) - for sc.Scan() { - lines = append(lines, sc.Text()) +// NewLoadBalancerConfig - instantiates a new load balancer config +func NewLoadBalancerConfig(configFilePath string, lbClusterLog logger.Logger) Config { + return &LBConfig{ + configFilePath: configFilePath, + lbLog: lbClusterLog, } - return lines, sc.Err() } -//LoadClusters checks the syntax of the clusters defined in the configuration file -func LoadClusters(config *Config, lg *lbcluster.Log) ([]lbcluster.LBCluster, error) { - var lbc lbcluster.LBCluster - var lbcs []lbcluster.LBCluster +func (c *LBConfig) GetMasterHost() string { + return c.Master +} - for k, v := range config.Clusters { - if len(v) == 0 { - lg.Warning("cluster: " + k + " ignored as it has no members defined in the configuration file " + config.ConfigFile) - continue +func (c *LBConfig) SetMasterHost(masterHostName string) { + c.Master = masterHostName +} + +func (c *LBConfig) GetMetricDirectoryPath() string { + return c.metricDirectoryPath +} + +func (c *LBConfig) GetHeartBeatFileName() string { + return c.HeartbeatFile +} + +func (c *LBConfig) SetHeartBeatFileName(heartBeatFileName string) { + c.HeartbeatFile = heartBeatFileName +} + +func (c *LBConfig) GetHeartBeatDirPath() string { + return c.HeartbeatPath +} + +func (c *LBConfig) SetHeartBeatDirPath(heartBeatDirPath string) { + c.HeartbeatPath = heartBeatDirPath +} + +func (c *LBConfig) GetDNSManager() string { + return c.DNSManager +} + +func (c *LBConfig) SetDNSManager(dnsManager string) { + c.DNSManager = dnsManager +} + +func (c *LBConfig) GetTSIGKeyPrefix() string { + return c.TsigKeyPrefix +} + +func (c *LBConfig) SetTSIGKeyPrefix(tsigKeyPrefix string) { + c.TsigKeyPrefix = tsigKeyPrefix +} + +func (c *LBConfig) GetTSIGInternalKey() string { + return c.TsigInternalKey +} + +func (c *LBConfig) SetTSIGInternalKey(tsigInternalKey string) { + c.TsigInternalKey = tsigInternalKey +} + +func (c *LBConfig) GetTSIGExternalKey() string { + return c.TsigExternalKey +} + +func (c *LBConfig) SetTSIGExternalKey(tsigExternalKey string) { + c.TsigExternalKey = tsigExternalKey +} + +func (c *LBConfig) SetClusters(clusters map[string][]string) { + c.Clusters = clusters +} + +func (c *LBConfig) SetParameters(params map[string]lbcluster.Params) { + c.Parameters = params +} + +func (c *LBConfig) SetSNMPPassword(password string) { + c.SnmpPassword = password +} + +func (c *LBConfig) LockHeartBeatMutex() { + c.HeartbeatMu.Lock() +} + +func (c *LBConfig) UnlockHeartBeatMutex() { + c.HeartbeatMu.Unlock() +} + +func (c *LBConfig) WatchFileChange(controlChan <-chan bool, waitGroup sync.WaitGroup) <-chan ConfigFileChangeSignal { + fileWatcherChan := make(chan ConfigFileChangeSignal) + waitGroup.Add(1) + go func() { + defer close(fileWatcherChan) + defer waitGroup.Done() + initialStat, err := os.Stat(c.configFilePath) + if err != nil { + fileWatcherChan <- ConfigFileChangeSignal{readError: err} } - if par, ok := config.Parameters[k]; ok { - lbc = lbcluster.LBCluster{Cluster_name: k, Loadbalancing_username: "loadbalancing", - Loadbalancing_password: config.SnmpPassword, Parameters: par, - Current_best_ips: []net.IP{}, - Previous_best_ips_dns: []net.IP{}, - Slog: lg} - hm := make(map[string]lbcluster.Node) - for _, h := range v { - hm[h] = lbcluster.Node{Load: 100000, IPs: []net.IP{}} + secondTicker := time.NewTicker(1 * time.Second) + for { + select { + case <-secondTicker.C: + stat, err := os.Stat(c.configFilePath) + if err != nil { + fileWatcherChan <- ConfigFileChangeSignal{readError: err} + continue + } + if stat.Size() != initialStat.Size() || stat.ModTime() != initialStat.ModTime() { + fileWatcherChan <- ConfigFileChangeSignal{readSignal: true} + initialStat = stat + } + case <-controlChan: + return } - lbc.Host_metric_table = hm - lbcs = append(lbcs, lbc) - lbc.Write_to_log("INFO", "(re-)loaded cluster ") - - } else { - lg.Warning("cluster: " + k + " missing parameters for cluster; ignoring the cluster, please check the configuration file " + config.ConfigFile) } - } + }() + return fileWatcherChan - return lbcs, nil +} +func (c *LBConfig) Load() ([]lbcluster.LBCluster, error) { + var configFunc func() ([]lbcluster.LBCluster, error) + + if strings.HasSuffix(c.configFilePath, ".yaml") { + configFunc = c.loadConfigYaml + } else { + configFunc = c.loadConfigOriginal + } + return configFunc() } //LoadConfigYaml reads a YAML configuration file and returns a struct with the config -func loadConfigYaml(configFile string, lg *lbcluster.Log) (*Config, []lbcluster.LBCluster, error) { - var config Config - - configBytes, err := os.ReadFile(configFile) +func (c *LBConfig) loadConfigYaml() ([]lbcluster.LBCluster, error) { + var logger = c.lbLog + configBytes, err := os.ReadFile(c.configFilePath) if err != nil { - return nil, nil, err + return nil, err } - if err := yaml.Unmarshal(configBytes, &config); err != nil { - return nil, nil, err + if err := yaml.Unmarshal(configBytes, &c); err != nil { + return nil, err } - config.ConfigFile = configFile - - lbclusters, err := LoadClusters(&config, lg) + c.lbLog = logger + lbclusters, err := c.LoadClusters() if err != nil { fmt.Println("Error getting the clusters") - return nil, nil, err + return nil, err } - lg.Info("Clusters loaded") + c.lbLog.Info("Clusters loaded") - return &config, lbclusters, nil + return lbclusters, nil } //LoadConfig reads a configuration file and returns a struct with the config -func loadConfigOriginal(configFile string, lg *lbcluster.Log) (*Config, []lbcluster.LBCluster, error) { +func (c *LBConfig) loadConfigOriginal() ([]lbcluster.LBCluster, error) { var ( - config Config - p lbcluster.Params - mc = make(map[string][]string) - mp = make(map[string]lbcluster.Params) + p lbcluster.Params + mc = make(map[string][]string) + mp = make(map[string]lbcluster.Params) ) - lines, err := readLines(configFile) + lines, err := readLines(c.configFilePath) if err != nil { - return nil, nil, err + return nil, err } for _, line := range lines { if strings.HasPrefix(line, "#") || (line == "") { @@ -137,23 +254,23 @@ func loadConfigOriginal(configFile string, lg *lbcluster.Log) (*Config, []lbclus if words[1] == "=" { switch words[0] { case "master": - config.Master = words[2] + c.Master = words[2] case "heartbeat_path": - config.HeartbeatPath = words[2] + c.HeartbeatPath = words[2] case "heartbeat_file": - config.HeartbeatFile = words[2] + c.HeartbeatFile = words[2] case "tsig_key_prefix": - config.TsigKeyPrefix = words[2] + c.TsigKeyPrefix = words[2] case "tsig_internal_key": - config.TsigInternalKey = words[2] + c.TsigInternalKey = words[2] case "tsig_external_key": - config.TsigExternalKey = words[2] + c.TsigExternalKey = words[2] case "snmpd_password": - config.SnmpPassword = words[2] + c.SnmpPassword = words[2] case "dns_manager": - config.DNSManager = words[2] - if !strings.Contains(config.DNSManager, ":") { - config.DNSManager += ":53" + c.DNSManager = words[2] + if !strings.Contains(c.DNSManager, ":") { + c.DNSManager += ":53" } } } else if words[2] == "=" { @@ -180,28 +297,88 @@ func loadConfigOriginal(configFile string, lg *lbcluster.Log) (*Config, []lbclus break } else if err != nil { //log.Fatal(err) - lg.Warning(fmt.Sprintf("%v", err)) + c.lbLog.Warning(fmt.Sprintf("%v", err)) os.Exit(1) } mp[words[1]] = p } else if words[0] == "clusters" { mc[words[1]] = words[3:] - lg.Debug(words[1]) - lg.Debug(fmt.Sprintf("%v", words[3:])) + c.lbLog.Debug(words[1]) + c.lbLog.Debug(fmt.Sprintf("%v", words[3:])) } } } - config.Parameters = mp - config.Clusters = mc - config.ConfigFile = configFile + c.Parameters = mp + c.Clusters = mc - lbclusters, err := LoadClusters(&config, lg) + lbclusters, err := c.LoadClusters() if err != nil { fmt.Println("Error getting the clusters") - return nil, nil, err + return nil, err + } + c.lbLog.Info("Clusters loaded") + + return lbclusters, nil + +} + +// readLines reads a whole file into memory and returns a slice of lines. +func readLines(path string) (lines []string, err error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + sc := bufio.NewScanner(f) + for sc.Scan() { + lines = append(lines, sc.Text()) } - lg.Info("Clusters loaded") + return lines, sc.Err() +} + +//LoadClusters checks the syntax of the clusters defined in the configuration file +func (c *LBConfig) LoadClusters() ([]lbcluster.LBCluster, error) { + var lbc lbcluster.LBCluster + var lbcs []lbcluster.LBCluster + hostName, err := os.Hostname() + if err != nil { + return nil, err + } + metricLogic := metric.NewLogic(DefaultMetricsDirectoryPath, hostName) + for k, v := range c.Clusters { + if len(v) == 0 { + c.lbLog.Warning("cluster: " + k + " ignored as it has no members defined in the configuration file " + c.configFilePath) + continue + } + if par, ok := c.Parameters[k]; ok { + lbcConfig := model.ClusterConfig{ + Cluster_name: k, + Loadbalancing_username: DefaultLoadBalancerConfig, + Loadbalancing_password: c.SnmpPassword, + } + lbc = lbcluster.LBCluster{ + ClusterConfig: lbcConfig, + Parameters: par, + Current_best_ips: []net.IP{}, + Previous_best_ips_dns: []net.IP{}, + Slog: c.lbLog, + MetricLogic: metricLogic, + } + hm := make(map[string]lbcluster.Node) + for _, h := range v { + hm[h] = lbcluster.Node{Load: 100000, IPs: []net.IP{}} + } + lbc.Host_metric_table = hm + lbcs = append(lbcs, lbc) + lbc.Slog.Info("(re-)loaded cluster ") + + } else { + c.lbLog.Warning("cluster: " + k + " missing parameters for cluster; ignoring the cluster, please check the configuration file " + c.configFilePath) + } + } + + return lbcs, nil - return &config, lbclusters, nil } diff --git a/lbd.go b/lbd.go index 17264e8..ca38898 100644 --- a/lbd.go +++ b/lbd.go @@ -4,19 +4,20 @@ import ( "flag" "fmt" "io/ioutil" - "log/syslog" + "lb-experts/golbd/metric" + "log" "math/rand" "os" - "os/signal" "regexp" "strconv" "strings" - "syscall" + "sync" "time" - "gitlab.cern.ch/lb-experts/golbd/lbcluster" - "gitlab.cern.ch/lb-experts/golbd/lbconfig" - "gitlab.cern.ch/lb-experts/golbd/lbhost" + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbconfig" + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" ) var ( @@ -37,19 +38,24 @@ var ( stdoutFlag = flag.Bool("stdout", false, "send log to stdtout") ) -const itCSgroupDNSserver string = "cfmgr.cern.ch" +const ( + shouldStartMetricServer = false // server disabled by default + itCSgroupDNSserver = "cfmgr.cern.ch" + DefaultSleepDuration = 10 + DefaultLbdTag = "lbd" + DefaultConnectionTimeout = 10 * time.Second + DefaultReadTimeout = 20 * time.Second +) -func shouldUpdateDNS(config *lbconfig.Config, hostname string, lg *lbcluster.Log) bool { - if hostname == config.Master { +func shouldUpdateDNS(config lbconfig.Config, hostname string, lg logger.Logger) bool { + if strings.EqualFold(hostname, config.GetMasterHost()) { return true } masterHeartbeat := "I am sick" - connectTimeout := (10 * time.Second) - readWriteTimeout := (20 * time.Second) - httpClient := lbcluster.NewTimeoutClient(connectTimeout, readWriteTimeout) - response, err := httpClient.Get("http://" + config.Master + "/load-balancing/" + config.HeartbeatFile) + httpClient := lbcluster.NewTimeoutClient(DefaultConnectionTimeout, DefaultReadTimeout) + response, err := httpClient.Get("http://" + config.GetMasterHost() + "/load-balancing/" + config.GetHeartBeatFileName()) if err != nil { - lg.Warning(fmt.Sprintf("problem fetching heartbeat file from the primary master %v: %v", config.Master, err)) + lg.Warning(fmt.Sprintf("problem fetching heartbeat file from the primary master %v: %v", config.GetMasterHost(), err)) return true } defer response.Body.Close() @@ -60,7 +66,7 @@ func shouldUpdateDNS(config *lbconfig.Config, hostname string, lg *lbcluster.Log lg.Debug(fmt.Sprintf("%s", contents)) masterHeartbeat = strings.TrimSpace(string(contents)) lg.Info("primary master heartbeat: " + masterHeartbeat) - r, _ := regexp.Compile(config.Master + ` : (\d+) : I am alive`) + r, _ := regexp.Compile(config.GetMasterHost() + ` : (\d+) : I am alive`) if r.MatchString(masterHeartbeat) { matches := r.FindStringSubmatch(masterHeartbeat) lg.Debug(fmt.Sprintf(matches[1])) @@ -82,131 +88,130 @@ func shouldUpdateDNS(config *lbconfig.Config, hostname string, lg *lbcluster.Log } -func updateHeartbeat(config *lbconfig.Config, hostname string, lg *lbcluster.Log) error { - if hostname != config.Master { +func updateHeartbeat(config lbconfig.Config, hostname string, lg logger.Logger) error { + if hostname != config.GetMasterHost() { return nil } - heartbeatFile := config.HeartbeatPath + "/" + config.HeartbeatFile + "temp" - heartbeatFileReal := config.HeartbeatPath + "/" + config.HeartbeatFile + heartbeatTempFilePath := config.GetHeartBeatDirPath() + "/" + config.GetHeartBeatFileName() + "temp" + heartbeatFileRealFilePath := config.GetHeartBeatDirPath() + "/" + config.GetHeartBeatFileName() - config.HeartbeatMu.Lock() - defer config.HeartbeatMu.Unlock() + //todo: read from channel + config.LockHeartBeatMutex() + defer config.UnlockHeartBeatMutex() - f, err := os.OpenFile(heartbeatFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + err := updateHeartBeatToFile(heartbeatTempFilePath, hostname, lg) if err != nil { - lg.Error(fmt.Sprintf("can not open %v for writing: %v", heartbeatFile, err)) return err } - now := time.Now() - secs := now.Unix() - _, err = fmt.Fprintf(f, "%v : %v : I am alive\n", hostname, secs) - lg.Info("updating: heartbeat file " + heartbeatFile) - if err != nil { - lg.Info(fmt.Sprintf("can not write to %v: %v", heartbeatFile, err)) - } - f.Close() - if err = os.Rename(heartbeatFile, heartbeatFileReal); err != nil { - lg.Error(fmt.Sprintf("can not rename %v to %v: %v", heartbeatFile, heartbeatFileReal, err)) + // todo: could the file be reused for any other use cases? + if err = os.Rename(heartbeatTempFilePath, heartbeatFileRealFilePath); err != nil { + lg.Error(fmt.Sprintf("can not rename %v to %v: %v", heartbeatTempFilePath, heartbeatFileRealFilePath, err)) return err } return nil } -func installSignalHandler(sighup, sigterm *bool, lg *lbcluster.Log) { - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGTERM, syscall.SIGHUP) +func updateHeartBeatToFile(heartBeatFilePath string, hostname string, lg logger.Logger) error { + secs := time.Now().Unix() + f, err := os.OpenFile(heartBeatFilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + defer f.Close() + if err != nil { + lg.Error(fmt.Sprintf("can not open %v for writing: %v", heartBeatFilePath, err)) + return err + } + _, err = fmt.Fprintf(f, "%v : %v : I am alive\n", hostname, secs) + lg.Info("updating: heartbeat file " + heartBeatFilePath) + if err != nil { + lg.Info(fmt.Sprintf("can not write to %v: %v", heartBeatFilePath, err)) + } + return nil +} +func sleep(seconds time.Duration, controlChan <-chan bool, waitGroup sync.WaitGroup) <-chan bool { + sleepSignalChan := make(chan bool) + waitGroup.Add(1) + secondsTicker := time.NewTicker(seconds * time.Second) go func() { + defer waitGroup.Done() for { - // Block until a signal is received. - sig := <-c - lg.Info(fmt.Sprintf("\nGiven signal: %v\n", sig)) - switch sig { - case syscall.SIGHUP: - *sighup = true - case syscall.SIGTERM: - *sigterm = true + select { + case <-secondsTicker.C: + sleepSignalChan <- true + break + case <-controlChan: + return } } }() + return sleepSignalChan } -/* Using this one (instead of fsnotify) -to check also if the file has been moved*/ -func watchFile(filePath string, chanModified chan int) error { - initialStat, err := os.Stat(filePath) +func main() { + wg := sync.WaitGroup{} + logger, err := logger.NewLoggerFactory(*logFileFlag) if err != nil { - return err + fmt.Printf("error during log initialization. error: %v", err) + os.Exit(1) } - for { - stat, err := os.Stat(filePath) - if err == nil { - if stat.Size() != initialStat.Size() || stat.ModTime() != initialStat.ModTime() { - chanModified <- 1 - initialStat = stat - } - } - time.Sleep(1 * time.Second) + if *stdoutFlag { + logger.EnableWriteToSTd() } -} - -func sleep(seconds time.Duration, chanModified chan int) error { - for { - chanModified <- 2 - time.Sleep(seconds * time.Second) - } - return nil -} - -func main() { + controlChan := make(chan bool) + defer close(controlChan) + defer wg.Done() + defer logger.Error("The lbd is not supposed to stop") flag.Parse() if *versionFlag { fmt.Printf("This is a proof of concept golbd version: %s-%s \n", Version, Release) os.Exit(0) } rand.Seed(time.Now().UTC().UnixNano()) - log, e := syslog.New(syslog.LOG_NOTICE, "lbd") - - if e != nil { - fmt.Printf("Error getting a syslog instance %v\nThe service will only write to the logfile %v\n\n", e, *logFileFlag) - } - lg := lbcluster.Log{SyslogWriter: log, Stdout: *stdoutFlag, Debugflag: *debugFlag, TofilePath: *logFileFlag} - - lg.Info("Starting lbd") - - // var sig_hup, sig_term bool - // installSignalHandler(&sig_hup, &sig_term, &lg) - config, lbclusters, err := lbconfig.LoadConfig(*configFileFlag, &lg) + logger.Info("Starting lbd") + lbConfig := lbconfig.NewLoadBalancerConfig(*configFileFlag, logger) + lbclusters, err := lbConfig.Load() if err != nil { - lg.Warning("loadConfig Error: ") - lg.Warning(err.Error()) + logger.Warning("loadConfig Error: ") + logger.Warning(err.Error()) os.Exit(1) } - lg.Info("Clusters loaded") + logger.Info("Clusters loaded") - doneChan := make(chan int) - go watchFile(*configFileFlag, doneChan) - go sleep(10, doneChan) + fileChangeSignal := lbConfig.WatchFileChange(controlChan, wg) + intervalTickerSignal := sleep(DefaultSleepDuration, controlChan, wg) + if shouldStartMetricServer { + go func() { + err := metric.NewMetricServer(lbconfig.DefaultMetricsDirectoryPath) + if err != nil { + logger.Error(fmt.Sprintf("error while starting metric server . error: %v", err)) + } + }() + } for { - myValue := <-doneChan - if myValue == 1 { - lg.Info("Config Changed") - config, lbclusters, err = lbconfig.LoadConfig(*configFileFlag, &lg) + select { + case fileWatcherData := <-fileChangeSignal: + if fileWatcherData.IsErrorPresent() { + // stop all operations + controlChan <- true + return + } + logger.Info("ClusterConfig Changed") + lbclusters, err = lbConfig.Load() if err != nil { - lg.Error(fmt.Sprintf("Error getting the clusters (something wrong in %v", configFileFlag)) + logger.Error(fmt.Sprintf("Error getting the clusters (something wrong in %v", configFileFlag)) } - } else if myValue == 2 { - checkAliases(config, lg, lbclusters) - } else { - lg.Error("Got an unexpected value") + case <-intervalTickerSignal: + checkAliases(lbConfig, logger, lbclusters) + break } } - lg.Error("The lbd is not supposed to stop") - } -func checkAliases(config *lbconfig.Config, lg lbcluster.Log, lbclusters []lbcluster.LBCluster) { + +func checkAliases(config lbconfig.Config, lg logger.Logger, lbclusters []lbcluster.LBCluster) { + hostCheckChannel := make(chan lbhost.Host) + defer close(hostCheckChannel) + hostname, e := os.Hostname() if e == nil { lg.Info("Hostname: " + hostname) @@ -215,55 +220,58 @@ func checkAliases(config *lbconfig.Config, lg lbcluster.Log, lbclusters []lbclus //var wg sync.WaitGroup updateDNS := true lg.Info("Checking if any of the " + strconv.Itoa(len(lbclusters)) + " clusters needs updating") - hostsToCheck := make(map[string]lbhost.LBHost) var clustersToUpdate []*lbcluster.LBCluster + hostsToCheck := make(map[string]lbhost.Host) /* First, let's identify the hosts that have to be checked */ for i := range lbclusters { - pc := &lbclusters[i] - pc.Write_to_log("DEBUG", "DO WE HAVE TO UPDATE?") - if pc.Time_to_refresh() { - pc.Write_to_log("INFO", "Time to refresh the cluster") - pc.Get_list_hosts(hostsToCheck) - clustersToUpdate = append(clustersToUpdate, pc) + currentCluster := &lbclusters[i] + lg.Debug("DO WE HAVE TO UPDATE?") + if currentCluster.Time_to_refresh() { + lg.Info("Time to refresh the cluster") + currentCluster.GetHostList(hostsToCheck) + clustersToUpdate = append(clustersToUpdate, currentCluster) } } - if len(hostsToCheck) != 0 { - myChannel := make(chan lbhost.LBHost) + if len(hostsToCheck) > 0 { /* Now, let's go through the hosts, issuing the snmp call */ for _, hostValue := range hostsToCheck { - go func(myHost lbhost.LBHost) { - myHost.Snmp_req() - myChannel <- myHost + go func(myHost lbhost.Host) { + myHost.SNMPDiscovery() + hostCheckChannel <- myHost }(hostValue) } - lg.Debug("Let's start gathering the results") - for i := 0; i < len(hostsToCheck); i++ { - myNewHost := <-myChannel - hostsToCheck[myNewHost.Host_name] = myNewHost + lg.Debug("start gathering the results") + for hostChanData := range hostCheckChannel { + hostsToCheck[hostChanData.GetName()] = hostChanData } lg.Debug("All the hosts have been tested") - updateDNS = shouldUpdateDNS(config, hostname, &lg) + updateDNS = shouldUpdateDNS(config, hostname, lg) /* Finally, let's go through the aliases, selecting the best hosts*/ for _, pc := range clustersToUpdate { - pc.Write_to_log("DEBUG", "READY TO UPDATE THE CLUSTER") - if pc.FindBestHosts(hostsToCheck) { + lg.Debug("READY TO UPDATE THE CLUSTER") + isDNSUpdateValid, err := pc.FindBestHosts(hostsToCheck) + if err != nil { + log.Fatalf("Error while finding best hosts. error:%v", err) + } + if isDNSUpdateValid { if updateDNS { - pc.Write_to_log("DEBUG", "Should update dns is true") - pc.RefreshDNS(config.DNSManager, config.TsigKeyPrefix, config.TsigInternalKey, config.TsigExternalKey) + lg.Debug("Should update dns is true") + // todo: try to implement retry mechanismlbcluster/lbcluster_dns.go + pc.RefreshDNS(config.GetDNSManager(), config.GetTSIGKeyPrefix(), config.GetTSIGInternalKey(), config.GetTSIGExternalKey()) } else { - pc.Write_to_log("DEBUG", "should_update_dns false") + lg.Debug("should_update_dns false") } } else { - pc.Write_to_log("DEBUG", "FindBestHosts false") + lg.Debug("FindBestHosts false") } } } if updateDNS { - updateHeartbeat(config, hostname, &lg) + updateHeartbeat(config, hostname, lg) } lg.Debug("iteration done!") diff --git a/lbhost/lbhost.go b/lbhost/lbhost.go index 30c1312..37e6c61 100644 --- a/lbhost/lbhost.go +++ b/lbhost/lbhost.go @@ -1,27 +1,23 @@ package lbhost import ( - // "encoding/json" "fmt" - //"io/ioutil" - "github.com/reguero/go-snmplib" - //"math/rand" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" "net" - "os" "regexp" "strconv" - "strings" "sync" - - // "net/http" - - // "sort" - // "strings" "time" + + "github.com/reguero/go-snmplib" ) -const TIMEOUT int = 10 -const OID string = ".1.3.6.1.4.1.96.255.1" +const ( + TIMEOUT int = 10 + OID string = ".1.3.6.1.4.1.96.255.1" + DefaultResponseInt = 100000 +) type LBHostTransportResult struct { Transport string @@ -31,111 +27,153 @@ type LBHostTransportResult struct { Response_error string } type LBHost struct { - Cluster_name string - Host_name string - Host_transports []LBHostTransportResult - Loadbalancing_username string - Loadbalancing_password string - LogFile string - logMu sync.Mutex - Debugflag bool -} - -func (self *LBHost) Snmp_req() { - - self.find_transports() - - for i, my_transport := range self.Host_transports { - my_transport.Response_int = 100000 - transport := my_transport.Transport - node_ip := my_transport.IP.String() - /* There is no need to put square brackets around the ipv6 addresses*/ - self.Write_to_log("DEBUG", "Checking the host "+node_ip+" with "+transport) - snmp, err := snmplib.NewSNMPv3(node_ip, self.Loadbalancing_username, "MD5", self.Loadbalancing_password, "NOPRIV", self.Loadbalancing_password, - time.Duration(TIMEOUT)*time.Second, 2) - if err != nil { - // Failed to create snmpgo.SNMP object - my_transport.Response_error = fmt.Sprintf("contacted node: error creating the snmp object: %v", err) - } else { - defer snmp.Close() - err = snmp.Discover() + ClusterConfig model.ClusterConfig + Host_name string + HostTransports []LBHostTransportResult + Logger logger.Logger + SnmpAgent DiscoveryAgent +} - if err != nil { - my_transport.Response_error = fmt.Sprintf("contacted node: error in the snmp discovery: %v", err) +type snmpDiscoveryResult struct { + hostIdx int + hostTransportResult LBHostTransportResult +} - } else { +type DiscoveryAgent interface { + Close() error + Discover() error + GetV3(oid snmplib.Oid) (interface{}, error) +} - oid, err := snmplib.ParseOid(OID) +func NewHostDiscoveryAgent(nodeIp string, clusterConfig model.ClusterConfig) (DiscoveryAgent, error) { + return snmplib.NewSNMPv3(nodeIp, clusterConfig.Loadbalancing_username, "MD5", + clusterConfig.Loadbalancing_password, "NOPRIV", clusterConfig.Loadbalancing_password, + time.Duration(TIMEOUT)*time.Second, 2) +} - if err != nil { - // Failed to parse Oids - my_transport.Response_error = fmt.Sprintf("contacted node: Error parsing the OID %v", err) +type Host interface { + GetName() string + SetName(name string) + SNMPDiscovery() + GetClusterConfig() *model.ClusterConfig + GetLoadForAlias(clusterName string) int + GetWorkingIPs() ([]net.IP, error) + GetAllIPs() ([]net.IP, error) + GetIps() ([]net.IP, error) + SetTransportPayload(transportPayloadList []LBHostTransportResult) + GetHostTransportPayloads() []LBHostTransportResult +} - } else { - pdu, err := snmp.GetV3(oid) +func NewLBHost(clusterConfig model.ClusterConfig, logger logger.Logger) Host { + return &LBHost{ + ClusterConfig: clusterConfig, + Logger: logger, + } +} - if err != nil { - my_transport.Response_error = fmt.Sprintf("contacted node: The getv3 gave the following error: %v ", err) +func (lh *LBHost) SetName(name string) { + lh.Host_name = name +} - } else { +func (lh *LBHost) GetName() string { + return lh.Host_name +} +func (lh *LBHost) GetClusterConfig() *model.ClusterConfig { + return &lh.ClusterConfig +} - self.Write_to_log("INFO", fmt.Sprintf("contacted node: transport: %v ip: %v - reply was %v", transport, node_ip, pdu)) +func (lh *LBHost) GetHostTransportPayloads() []LBHostTransportResult { + return lh.HostTransports +} - //var pduInteger int - switch t := pdu.(type) { - case int: - my_transport.Response_int = pdu.(int) - case string: - my_transport.Response_string = pdu.(string) - default: - my_transport.Response_error = fmt.Sprintf("The node returned an unexpected type %s in %v", t, pdu) - } - } - } - } - } - self.Host_transports[i] = my_transport +func (lh *LBHost) SetTransportPayload(transportPayloadList []LBHostTransportResult) { + lh.HostTransports = transportPayloadList +} +func (lh *LBHost) SNMPDiscovery() { + var wg sync.WaitGroup + lh.find_transports() + discoveryResultChan := make(chan snmpDiscoveryResult) + defer close(discoveryResultChan) + hostTransportResultList := make([]LBHostTransportResult, 0, len(lh.HostTransports)) + hostTransportResultList = append(hostTransportResultList, lh.HostTransports...) + for i, hostTransport := range lh.HostTransports { + wg.Add(1) + go func(idx int, hostTransport LBHostTransportResult) { + defer wg.Done() + lh.discoverNode(idx, hostTransport, discoveryResultChan) + }(i, hostTransport) } - - self.Write_to_log("DEBUG", "All the ips have been tested") - /*for _, my_transport := range self.Host_transports { - self.Write_to_log("INFO", fmt.Sprintf("%v", my_transport)) - }*/ + go func(discoveryResultChan <-chan snmpDiscoveryResult) { + for discoveryResultData := range discoveryResultChan { + hostTransportResultList[discoveryResultData.hostIdx] = discoveryResultData.hostTransportResult + } + }(discoveryResultChan) + wg.Wait() + lh.HostTransports = hostTransportResultList + lh.Logger.Debug("All the ips have been tested") } -func (self *LBHost) Write_to_log(level string, msg string) error { +func (lh *LBHost) discoverNode(hostTransportIdx int, hostTransport LBHostTransportResult, resultChan chan<- snmpDiscoveryResult) { + var snmpAgent DiscoveryAgent var err error - if level == "DEBUG" && !self.Debugflag { - //The debug messages should not appear - return nil + hostTransport.Response_int = DefaultResponseInt + nodeIp := hostTransport.IP.String() + lh.Logger.Debug("Checking the host " + nodeIp + " with " + hostTransport.Transport) + if lh.SnmpAgent == nil { + snmpAgent, err = NewHostDiscoveryAgent(nodeIp, lh.ClusterConfig) + if err != nil { + hostTransport.Response_error = fmt.Sprintf("contacted node: error creating the snmp object: %v", err) + } + } else { + snmpAgent = lh.SnmpAgent } - if !strings.HasSuffix(msg, "\n") { - msg += "\n" + if err == nil { + defer snmpAgent.Close() + err = snmpAgent.Discover() + if err != nil { + hostTransport.Response_error = fmt.Sprintf("contacted node: error in the snmp discovery: %v", err) + } else { + lh.setTransportResponse(snmpAgent, &hostTransport) + } } - timestamp := time.Now().Format(time.StampMilli) - msg = fmt.Sprintf("%s lbd[%d]: %s: cluster: %s node: %s %s", timestamp, os.Getpid(), level, self.Cluster_name, self.Host_name, msg) - self.logMu.Lock() - defer self.logMu.Unlock() + resultChan <- snmpDiscoveryResult{ + hostIdx: hostTransportIdx, + hostTransportResult: hostTransport, + } +} - f, err := os.OpenFile(self.LogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0640) +func (lh *LBHost) setTransportResponse(snmpClient DiscoveryAgent, lbHostTransportResultPayload *LBHostTransportResult) { + oid, err := snmplib.ParseOid(OID) if err != nil { - return err + lbHostTransportResultPayload.Response_error = fmt.Sprintf("contacted node: Error parsing the OID %v", err) + return + } + pdu, err := snmpClient.GetV3(oid) + if err != nil { + lbHostTransportResultPayload.Response_error = fmt.Sprintf("contacted node: The getv3 gave the following error: %v ", err) + return + } + lh.Logger.Info(fmt.Sprintf("contacted node: transport: %v ip: %v - reply was %v", lbHostTransportResultPayload.Transport, lbHostTransportResultPayload.IP.String(), pdu)) + switch t := pdu.(type) { + case int: + lbHostTransportResultPayload.Response_int = pdu.(int) + case string: + lbHostTransportResultPayload.Response_string = pdu.(string) + default: + lbHostTransportResultPayload.Response_error = fmt.Sprintf("The node returned an unexpected type %s in %v", t, pdu) } - defer f.Close() - _, err = fmt.Fprintf(f, msg) - - return err } -func (self *LBHost) Get_load_for_alias(cluster_name string) int { +// todo: instead of polling try adhoc webhook updates +func (lh *LBHost) GetLoadForAlias(clusterName string) int { my_load := -200 - for _, my_transport := range self.Host_transports { + for _, my_transport := range lh.HostTransports { pduInteger := my_transport.Response_int - re := regexp.MustCompile(cluster_name + "=([0-9]+)") + re := regexp.MustCompile(clusterName + "=([0-9]+)") submatch := re.FindStringSubmatch(my_transport.Response_string) if submatch != nil { @@ -145,80 +183,75 @@ func (self *LBHost) Get_load_for_alias(cluster_name string) int { if (pduInteger > 0 && pduInteger < my_load) || (my_load < 0) { my_load = pduInteger } - self.Write_to_log("DEBUG", fmt.Sprintf("Possible load is %v", pduInteger)) + lh.Logger.Debug(fmt.Sprintf("Possible load is %v", pduInteger)) } - self.Write_to_log("DEBUG", fmt.Sprintf("THE LOAD IS %v, ", my_load)) + lh.Logger.Debug(fmt.Sprintf("THE LOAD IS %v, ", my_load)) return my_load } -func (self *LBHost) Get_working_IPs() ([]net.IP, error) { +func (lh *LBHost) GetWorkingIPs() ([]net.IP, error) { var my_ips []net.IP - for _, my_transport := range self.Host_transports { + for _, my_transport := range lh.HostTransports { if (my_transport.Response_int > 0) && (my_transport.Response_error == "") { my_ips = append(my_ips, my_transport.IP) } } - self.Write_to_log("INFO", fmt.Sprintf("The ips for this host are %v", my_ips)) + lh.Logger.Info(fmt.Sprintf("The ips for this host are %v", my_ips)) return my_ips, nil } -func (self *LBHost) Get_all_IPs() ([]net.IP, error) { +func (lh *LBHost) GetAllIPs() ([]net.IP, error) { var my_ips []net.IP - for _, my_transport := range self.Host_transports { + for _, my_transport := range lh.HostTransports { my_ips = append(my_ips, my_transport.IP) } - self.Write_to_log("INFO", fmt.Sprintf("All ips for this host are %v", my_ips)) + lh.Logger.Info(fmt.Sprintf("All ips for this host are %v", my_ips)) return my_ips, nil } -func (self *LBHost) Get_Ips() ([]net.IP, error) { - +func (lh *LBHost) GetIps() ([]net.IP, error) { var ips []net.IP - var err error - re := regexp.MustCompile(".*no such host") - net.DefaultResolver.StrictErrors = true - for i := 0; i < 3; i++ { - self.Write_to_log("INFO", "Getting the ip addresses") - ips, err = net.LookupIP(self.Host_name) + lh.Logger.Info("Getting the ip addresses") + ips, err = net.LookupIP(lh.Host_name) if err == nil { return ips, nil } - self.Write_to_log("WARNING", fmt.Sprintf("LookupIP: %v has incorrect or missing IP address (%v) ", self.Host_name, err)) + lh.Logger.Info(fmt.Sprintf("LookupIP: %v has incorrect or missing IP address (%v) ", lh.Host_name, err)) submatch := re.FindStringSubmatch(err.Error()) if submatch != nil { - self.Write_to_log("INFO", "There is no need to retry this error") + lh.Logger.Info("There is no need to retry this error") return nil, err } } - self.Write_to_log("ERROR", "After several retries, we couldn't get the ips!. Let's try with partial results") + lh.Logger.Error("After several retries, we couldn't get the ips!. Let's try with partial results") net.DefaultResolver.StrictErrors = false - ips, err = net.LookupIP(self.Host_name) + ips, err = net.LookupIP(lh.Host_name) if err != nil { - self.Write_to_log("ERROR", fmt.Sprintf("It didn't work :(. This node will be ignored during this evaluation: %v", err)) + lh.Logger.Error(fmt.Sprintf("It didn't work :(. This node will be ignored during this evaluation: %v", err)) } return ips, err } -func (self *LBHost) find_transports() { - self.Write_to_log("DEBUG", "Let's find the ips behind this host") +func (lh *LBHost) find_transports() { + lh.Logger.Debug("Let's find the ips behind this host") - ips, _ := self.Get_Ips() + ips, _ := lh.GetIps() for _, ip := range ips { transport := "udp" // If there is an IPv6 address use udp6 transport if ip.To4() == nil { transport = "udp6" } - self.Host_transports = append(self.Host_transports, LBHostTransportResult{Transport: transport, - Response_int: 100000, Response_string: "", IP: ip, + lh.HostTransports = append(lh.HostTransports, LBHostTransportResult{Transport: transport, + Response_int: DefaultResponseInt, Response_string: "", IP: ip, Response_error: ""}) } diff --git a/logger/lbcluster_log.go b/logger/lbcluster_log.go new file mode 100644 index 0000000..61bf951 --- /dev/null +++ b/logger/lbcluster_log.go @@ -0,0 +1,194 @@ +package logger + +import ( + "fmt" + "log/syslog" + "os" + "regexp" + "strings" + "sync" + "time" +) + +const ( + DefaultLbdTag = "lbd" + logLevelInfo = "INFO" + logLevelDebug = "DEBUG" + logLevelWarning = "WARNING" + logLevelError = "ERROR" +) + +//Log struct for the log +type Log struct { + logWriter *syslog.Writer + shouldWriteToSTD bool + isDebugAllowed bool + filePath string + logMu sync.Mutex + logStartTime time.Time + isSnapShotEnabled bool + snapShotCycleTime time.Duration + logFileBasePath string + logFileExtension string + snapshotCounter int +} + +//Logger struct for the Logger interface +type Logger interface { + EnableDebugMode() + EnableWriteToSTd() + StartSnapshot(d time.Duration) + GetLogFilePath() string + Info(s string) + Warning(s string) + Debug(s string) + Error(s string) +} + +func NewLoggerFactory(logFilePath string) (Logger, error) { + log, err := syslog.New(syslog.LOG_NOTICE, DefaultLbdTag) + if err != nil { + return nil, err + } + if strings.EqualFold(logFilePath, "") { + return nil, fmt.Errorf("empty log file path") + } + _, err = os.OpenFile(logFilePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0640) + if err != nil { + return nil, err + } + basePath, extension, err := getLogFilePathAndExtension(logFilePath) + if err != nil { + return nil, fmt.Errorf("error while validating log file path. error: %v", err) + } + + if !isLogFilePathValid(basePath, extension) { + return nil, fmt.Errorf("invalid log file path. log path: %s", logFilePath) + } + return &Log{ + logWriter: log, + logStartTime: time.Now(), + filePath: logFilePath, + logFileBasePath: basePath, + logFileExtension: extension, + }, + nil +} + +func isLogFilePathValid(basePath, extension string) bool { + return !strings.EqualFold(basePath, "") && !strings.EqualFold(extension, "") +} + +func getLogFilePathAndExtension(logFilePath string) (string, string, error) { + matcher := regexp.MustCompile(`(.*)\.(.*)`) + matchGroups := matcher.FindAllStringSubmatch(logFilePath, -1) + if len(matchGroups) == 0 || len(matchGroups[0]) < 3 { + return "", "", fmt.Errorf("log file path is not in the right format. path: %s", logFilePath) + } + return matchGroups[0][1], matchGroups[0][2], nil +} + +func (l *Log) EnableDebugMode() { + l.isDebugAllowed = true +} + +func (l *Log) EnableWriteToSTd() { + l.shouldWriteToSTD = true +} + +func (l *Log) StartSnapshot(d time.Duration) { + if !l.isSnapShotEnabled { + l.isSnapShotEnabled = true + l.snapShotCycleTime = d + l.startSnapShot() + } +} + +func (l *Log) GetLogFilePath() string { + return l.filePath +} +func (l *Log) startSnapShot() { + l.logStartTime = time.Now() + l.filePath = fmt.Sprintf("%s.%d.%s", l.logFileBasePath, l.snapshotCounter, l.logFileExtension) + l.snapshotCounter += 1 +} + +func (l *Log) shouldStartNewSnapshot() bool { + if l.isSnapShotEnabled { + return time.Now().Sub(l.logStartTime) >= l.snapShotCycleTime + } + return false +} + +//Info write as Info +func (l *Log) Info(s string) { + if l.logWriter != nil { + _ = l.logWriter.Info(s) + } + if l.shouldWriteToSTD || (l.filePath != "") { + l.write(fmt.Sprintf("%s: %s", logLevelInfo, s)) + } +} + +//Warning write as Warning +func (l *Log) Warning(s string) { + if l.logWriter != nil { + _ = l.logWriter.Warning(s) + } + if l.shouldWriteToSTD || (l.filePath != "") { + l.write(fmt.Sprintf("%s: %s", logLevelWarning, s)) + } +} + +//Debug write as Debug +func (l *Log) Debug(s string) { + if l.isDebugAllowed { + if l.logWriter != nil { + _ = l.logWriter.Debug(s) + } + if l.shouldWriteToSTD || (l.filePath != "") { + l.write(fmt.Sprintf("%s: %s", logLevelDebug, s)) + } + } +} + +//Error write as Error +func (l *Log) Error(s string) { + if l.logWriter != nil { + _ = l.logWriter.Err(s) + } + if l.shouldWriteToSTD || (l.filePath != "") { + l.write(fmt.Sprintf("%s: %s", logLevelError, s)) + } +} + +func (l *Log) write(s string) { + tag := "lbd" + nl := "" + if !strings.HasSuffix(s, "\n") { + nl = "\n" + } + timestamp := time.Now().Format(time.StampMilli) + msg := fmt.Sprintf("%s %s[%d]: %s%s", + timestamp, + tag, os.Getpid(), s, nl) + if l.shouldWriteToSTD { + fmt.Printf(msg) + } + l.logMu.Lock() + defer l.logMu.Unlock() + + if l.shouldStartNewSnapshot() { + l.startSnapShot() + } + f, err := os.OpenFile(l.filePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0640) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error while opening the log file. error: %v", err) + return + } + defer f.Close() + _, err = fmt.Fprintf(f, msg) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error while writing to the log file. error: %v", err) + } +} diff --git a/metric/metric_controller.go b/metric/metric_controller.go new file mode 100644 index 0000000..39254d2 --- /dev/null +++ b/metric/metric_controller.go @@ -0,0 +1,55 @@ +package metric + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" +) + +const defaultPort = ":8081" + +type serverCtx struct { + metricLogic Logic +} + +func NewMetricServer(metricDirectoryPath string) error { + currentHostname, err := os.Hostname() + if err != nil { + return fmt.Errorf("unable to fetch current host name. error: %v", err) + } + logic := NewLogic(metricDirectoryPath, currentHostname) + server := serverCtx{ + metricLogic: logic, + } + http.HandleFunc("/metric", server.fetchHostMetric) + err = http.ListenAndServe(defaultPort, nil) + if err != nil { + return err + } + return nil +} + +func (s *serverCtx) fetchHostMetric(resp http.ResponseWriter, req *http.Request) { + if !strings.EqualFold(req.Method, http.MethodGet) { + resp.WriteHeader(http.StatusBadRequest) + resp.Write([]byte("request method not supported")) + return + } + hostMetric, err := s.metricLogic.ReadHostMetric() + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + resp.Write([]byte(fmt.Sprintf("error while reading host metric. error: %v", err))) + return + } + responseData, err := json.Marshal(hostMetric) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + resp.Write([]byte(fmt.Sprintf("error while marshalling data. error:%v", err))) + return + } + resp.Header().Set("content-type", "application/json") + resp.WriteHeader(http.StatusOK) + resp.Write(responseData) +} diff --git a/metric/metric_logic.go b/metric/metric_logic.go new file mode 100644 index 0000000..8a1fca3 --- /dev/null +++ b/metric/metric_logic.go @@ -0,0 +1,118 @@ +package metric + +import ( + "encoding/csv" + "fmt" + "os" + "strings" + "sync" + "time" +) + +const ( + defaultMetricFileNamePrefix = "metric" + csvExtension = "csv" + defaultCycleDuration = 24 * time.Hour +) + +type Logic interface { + GetFilePath() string + ReadHostMetric() (HostMetric, error) + WriteRecord(property Property) error +} +type BizLogic struct { + isCurrentHostMaster bool + hostName string + filePath string + dirPath string + rwLocker sync.RWMutex + recordStartTime time.Time +} + +func NewLogic(dirPath string, hostName string) Logic { + logic := &BizLogic{ + dirPath: dirPath, + hostName: hostName, + } + logic.initNewFilePath() + return logic +} + +func (bl *BizLogic) ReadHostMetric() (HostMetric, error) { + propertyList, err := bl.readAllRecords() + if err != nil { + return HostMetric{}, err + } + return HostMetric{ + Name: bl.hostName, + PropertyList: propertyList, + }, nil +} + +func (bl *BizLogic) GetFilePath() string { + return bl.filePath +} + +func (bl *BizLogic) readAllRecords() ([]Property, error) { + var propertyResultList []Property + bl.rwLocker.RLock() + defer bl.rwLocker.RUnlock() + fp, err := os.Open(bl.filePath) + defer fp.Close() + if err != nil { + return propertyResultList, fmt.Errorf("error while reading metric. error: %v", err) + } + csvReader := csv.NewReader(fp) + + recordList, err := csvReader.ReadAll() + if err != nil { + return propertyResultList, fmt.Errorf("error while reading metric. error: %v", err) + } + for _, record := range recordList { + property, err := parseProperty(record) + if err != nil { + return propertyResultList, fmt.Errorf("error while parsing a property record. error: %s", err) + } + propertyResultList = append(propertyResultList, property) + } + return propertyResultList, nil +} + +func (bl *BizLogic) WriteRecord(property Property) error { + if err := property.validate(); err != nil { + return err + } + bl.rwLocker.Lock() + defer bl.rwLocker.Unlock() + if bl.shouldUpdateFilePath() { + bl.initNewFilePath() + } + property.RoundTripDuration = property.RoundTripEndTime.Sub(property.RoundTripStartTime) + if strings.EqualFold(bl.filePath, "") { + return fmt.Errorf("filePath cannot be empty") + } + + fp, err := os.OpenFile(bl.filePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0640) + defer fp.Close() + if err != nil { + return fmt.Errorf("error while recording metric. error: %v", err) + } + csvWriter := csv.NewWriter(fp) + err = csvWriter.Write(property.marshalToCSV()) + if err != nil { + return err + } + csvWriter.Flush() + + return nil +} + +func (bl *BizLogic) shouldUpdateFilePath() bool { + return time.Now().Sub(bl.recordStartTime) > defaultCycleDuration +} + +func (bl *BizLogic) initNewFilePath() { + curTime := time.Now() + bl.recordStartTime = curTime + bl.filePath = fmt.Sprintf("%s%s_%d_%d_%d.%s", bl.dirPath, defaultMetricFileNamePrefix, curTime.Year(), curTime.Month(), curTime.Day(), csvExtension) +} diff --git a/metric/metric_model.go b/metric/metric_model.go new file mode 100644 index 0000000..eea33bd --- /dev/null +++ b/metric/metric_model.go @@ -0,0 +1,67 @@ +package metric + +import ( + "fmt" + "time" +) + +type HostMetric struct { + Name string `json:"name" csv:"name" rw:"r"` + PropertyList []Property `json:"property_list" csv:"property_list" rw:"r"` +} + +type Property struct { + RoundTripStartTime time.Time `json:"start_time" csv:"start_time" rw:"r"` + RoundTripEndTime time.Time `json:"end_time" csv:"end_time" rw:"r"` + RoundTripDuration time.Duration `json:"round_trip_duration"` +} + +type ClusterMetric struct { + MasterName string + HostMetricList []HostMetric `json:"host_metric_list" csv:"host_metric_list" rw:"r"` +} + +func (p Property) validate() error { + var invalidFields []string + if p.RoundTripStartTime.IsZero() { + invalidFields = append(invalidFields, "start_time") + } + if p.RoundTripEndTime.IsZero() { + invalidFields = append(invalidFields, "end_time") + } + if len(invalidFields) != 0 { + return fmt.Errorf("following fields are not valid. %v", invalidFields) + } + return nil +} + +func (p Property) marshalToCSV() []string { + return []string{ + p.RoundTripStartTime.Format(time.RFC3339), + p.RoundTripEndTime.Format(time.RFC3339), + p.RoundTripDuration.String(), + } +} + +func parseProperty(csvRecord []string) (Property, error) { + if len(csvRecord) < 3 { + return Property{}, fmt.Errorf("insufficient columns. column count %d", len(csvRecord)) + } + parsedStartTime, err := time.Parse(time.RFC3339, csvRecord[0]) + if err != nil { + return Property{}, err + } + parsedEndTime, err := time.Parse(time.RFC3339, csvRecord[1]) + if err != nil { + return Property{}, err + } + parsedDuration, err := time.ParseDuration(csvRecord[2]) + if err != nil { + return Property{}, err + } + return Property{ + RoundTripStartTime: parsedStartTime, + RoundTripEndTime: parsedEndTime, + RoundTripDuration: parsedDuration, + }, nil +} diff --git a/model/cluster_config.go b/model/cluster_config.go new file mode 100644 index 0000000..3b3e2f0 --- /dev/null +++ b/model/cluster_config.go @@ -0,0 +1,7 @@ +package model + +type ClusterConfig struct { + Cluster_name string + Loadbalancing_username string + Loadbalancing_password string +} diff --git a/tests/aliasload_test.go b/tests/aliasload_test.go index bb45d3d..847cf75 100644 --- a/tests/aliasload_test.go +++ b/tests/aliasload_test.go @@ -1,49 +1,54 @@ package main_test import ( + "lb-experts/golbd/lbhost" + "os" "reflect" "testing" - - "gitlab.cern.ch/lb-experts/golbd/lbhost" ) -//Function TestGetLoadHosts tests the function Get_load_for_alias +//Function TestGetLoadHosts tests the function GetLoadForAlias func TestGetLoadHosts(t *testing.T) { - hosts := []lbhost.LBHost{ + hosts := []lbhost.Host{ getHost("lxplus132.cern.ch", 7, ""), getHost("lxplus132.cern.ch", 0, "blabla.cern.ch=179,blablabla2.cern.ch=4"), getHost("toto132.lxplus.cern.ch", 42, ""), getHost("toto132.lxplus.cern.ch", 0, "blabla.subdo.cern.ch=179,blablabla2.subdo.cern.ch=4"), } - expectedhost0 := hosts[0].Host_transports[0].Response_int - //expectedhost1 := hosts[1].Host_transports[0].Response_int - expectedhost2 := hosts[2].Host_transports[0].Response_int - //expectedhost3 := hosts[3].Host_transports[0].Response_int + expectedhost0 := hosts[0].GetHostTransportPayloads()[0].Response_int + //expectedhost1 := hosts[1].HostTransports[0].Response_int + expectedhost2 := hosts[2].GetHostTransportPayloads()[0].Response_int + //expectedhost3 := hosts[3].HostTransports[0].Response_int - if !reflect.DeepEqual(hosts[0].Get_load_for_alias(hosts[0].Cluster_name), expectedhost0) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[0].Get_load_for_alias(hosts[0].Cluster_name), expectedhost0) + if !reflect.DeepEqual(hosts[0].GetLoadForAlias(hosts[0].GetClusterConfig().Cluster_name), expectedhost0) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[0].GetLoadForAlias(hosts[0].GetClusterConfig().Cluster_name), expectedhost0) + } + if !reflect.DeepEqual(hosts[1].GetLoadForAlias(hosts[1].GetClusterConfig().Cluster_name), 0) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[1].GetLoadForAlias(hosts[1].GetClusterConfig().Cluster_name), 0) } - if !reflect.DeepEqual(hosts[1].Get_load_for_alias(hosts[1].Cluster_name), 0) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[1].Get_load_for_alias(hosts[1].Cluster_name), 0) + if !reflect.DeepEqual(hosts[1].GetLoadForAlias("blabla.cern.ch"), 179) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[1].GetLoadForAlias("blabla.cern.ch"), 179) } - if !reflect.DeepEqual(hosts[1].Get_load_for_alias("blabla.cern.ch"), 179) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[1].Get_load_for_alias("blabla.cern.ch"), 179) + if !reflect.DeepEqual(hosts[1].GetLoadForAlias("blablabla2.cern.ch"), 4) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[1].GetLoadForAlias("blablabla2.cern.ch"), 4) } - if !reflect.DeepEqual(hosts[1].Get_load_for_alias("blablabla2.cern.ch"), 4) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[1].Get_load_for_alias("blablabla2.cern.ch"), 4) + if !reflect.DeepEqual(hosts[2].GetLoadForAlias(hosts[2].GetClusterConfig().Cluster_name), expectedhost2) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[2].GetLoadForAlias(hosts[2].GetClusterConfig().Cluster_name), expectedhost2) } - if !reflect.DeepEqual(hosts[2].Get_load_for_alias(hosts[2].Cluster_name), expectedhost2) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[2].Get_load_for_alias(hosts[2].Cluster_name), expectedhost2) + if !reflect.DeepEqual(hosts[2].GetLoadForAlias("toto.subdo.cern.ch"), expectedhost2) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[2].GetLoadForAlias("toto.subdo.cern.ch"), expectedhost2) } - if !reflect.DeepEqual(hosts[2].Get_load_for_alias("toto.subdo.cern.ch"), expectedhost2) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[2].Get_load_for_alias("toto.subdo.cern.ch"), expectedhost2) + if !reflect.DeepEqual(hosts[3].GetLoadForAlias("blabla.subdo.cern.ch"), 179) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[3].GetLoadForAlias("blabla.subdo.cern.ch"), 179) } - if !reflect.DeepEqual(hosts[3].Get_load_for_alias("blabla.subdo.cern.ch"), 179) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[3].Get_load_for_alias("blabla.subdo.cern.ch"), 179) + if !reflect.DeepEqual(hosts[3].GetLoadForAlias("blablabla2.subdo.cern.ch"), 4) { + t.Errorf(" got\n%v\nexpected\n%v", hosts[3].GetLoadForAlias("blablabla2.subdo.cern.ch"), 4) } - if !reflect.DeepEqual(hosts[3].Get_load_for_alias("blablabla2.subdo.cern.ch"), 4) { - t.Errorf(" got\n%v\nexpected\n%v", hosts[3].Get_load_for_alias("blablabla2.subdo.cern.ch"), 4) + err := os.Remove("sample.log") + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) } } diff --git a/tests/apply_metric_internal_test.go b/tests/apply_metric_internal_test.go index 284e590..10064a6 100644 --- a/tests/apply_metric_internal_test.go +++ b/tests/apply_metric_internal_test.go @@ -3,6 +3,7 @@ package main_test import ( "fmt" "net" + "os" "reflect" "testing" ) @@ -49,4 +50,9 @@ func TestEvaluateMetric(t *testing.T) { if !reflect.DeepEqual(c.Time_of_last_evaluation, expected_time_of_last_evaluation) { t.Errorf("e.apply_metric: c.Time_of_last_evaluation: got\n%v\nexpected\n%v", c.Time_of_last_evaluation, expected_time_of_last_evaluation) } + err := os.Remove("sample.log") + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) + } } diff --git a/tests/evaluate_hosts_internal_test.go b/tests/evaluate_hosts_internal_test.go index 4387b8e..7185e9c 100644 --- a/tests/evaluate_hosts_internal_test.go +++ b/tests/evaluate_hosts_internal_test.go @@ -1,13 +1,66 @@ package main_test import ( + "fmt" + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" "net" + "os" "reflect" "testing" - - "gitlab.cern.ch/lb-experts/golbd/lbcluster" + "time" ) +type mockHost struct { +} + +func (m mockHost) GetHostTransportPayloads() []lbhost.LBHostTransportResult { + panic("implement me") +} + +func (m mockHost) SetName(name string) { + panic("implement me") +} + +func (m mockHost) SetTransportPayload(transportPayloadList []lbhost.LBHostTransportResult) { + panic("implement me") +} + +func (m mockHost) GetName() string { + panic("implement me") +} + +func (m mockHost) SNMPDiscovery() { + panic("implement me") +} + +func (m mockHost) GetClusterConfig() *model.ClusterConfig { + panic("implement me") +} + +func (m mockHost) GetLoadForAlias(clusterName string) int { + return 0 +} + +func (m mockHost) GetWorkingIPs() ([]net.IP, error) { + return []net.IP{}, fmt.Errorf("sample error") +} + +func (m mockHost) GetAllIPs() ([]net.IP, error) { + panic("implement me") +} + +func (m mockHost) GetIps() ([]net.IP, error) { + time.Sleep(5 * time.Second) // simulating a network request + return []net.IP{}, nil +} + +func NewMockHost() lbhost.Host { + return &mockHost{} +} + func compareIPs(t *testing.T, source, target []net.IP) { found := map[string]bool{} @@ -67,4 +120,25 @@ func TestEvaluateHosts(t *testing.T) { if !reflect.DeepEqual(c.Time_of_last_evaluation, expectedTimeOfLastEvaluation) { t.Errorf("e.evaluate_hosts: c.Time_of_last_evaluation: got\n%v\nexpected\n%v", c.Time_of_last_evaluation, expectedTimeOfLastEvaluation) } + err := os.Remove("sample.log") + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) + } +} + +func TestEvaluateHostsConcurrency(t *testing.T) { + mockHostMap := make(map[string]lbhost.Host) + mockHostMap["sampleHost"] = NewMockHost() + logger, _ := logger.NewLoggerFactory("sample.log") + cluster := lbcluster.LBCluster{Slog: logger} + cluster.Host_metric_table = map[string]lbcluster.Node{"sampleHost": {HostName: "sampleHost"}} + startTime := time.Now() + cluster.EvaluateHosts(mockHostMap) + endTime := time.Now() + if endTime.Sub(startTime) > 6*time.Second { + t.Fail() + t.Errorf("concurrent job not running properly. expDuration:%v, actualDuration:%v", 6, endTime.Sub(startTime)) + } + os.Remove("sample.log") } diff --git a/tests/find_best_hosts_test.go b/tests/find_best_hosts_test.go index 2ccc939..745e348 100644 --- a/tests/find_best_hosts_test.go +++ b/tests/find_best_hosts_test.go @@ -2,11 +2,12 @@ package main_test import ( "net" + "os" "reflect" "testing" "time" - "gitlab.cern.ch/lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbcluster" ) func getExpectedHostMetric() map[string]lbcluster.Node { @@ -27,8 +28,11 @@ func TestFindBestHosts(t *testing.T) { expected_host_metric_table := getExpectedHostMetric() expected_current_best_ips := []net.IP{net.ParseIP("2001:1458:d00:2c::100:a6"), net.ParseIP("188.184.108.98"), net.ParseIP("2001:1458:d00:32::100:51"), net.ParseIP("188.184.116.81")} - - if !c.FindBestHosts(hosts_to_check) { + isDNSUpdateValid, err := c.FindBestHosts(hosts_to_check) + if err != nil { + t.Errorf("Error while finding best hosts. error:%v", err) + } + if !isDNSUpdateValid { t.Errorf("e.Find_best_hosts: returned false, expected true") } if !reflect.DeepEqual(c.Host_metric_table, expected_host_metric_table) { @@ -40,6 +44,11 @@ func TestFindBestHosts(t *testing.T) { if c.Time_of_last_evaluation.Add(time.Duration(2) * time.Second).Before(time.Now()) { t.Errorf("e.Find_best_hosts: c.Time_of_last_evaluation: got\n%v\ncurrent time\n%v", c.Time_of_last_evaluation, time.Now()) } + err = os.Remove("sample.log") + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) + } } func TestFindBestHostsNoValidHostCmsfrontier(t *testing.T) { @@ -53,8 +62,11 @@ func TestFindBestHostsNoValidHostCmsfrontier(t *testing.T) { expected_current_best_ips := []net.IP{} expected_time_of_last_evaluation := c.Time_of_last_evaluation - - if c.FindBestHosts(bad_hosts_to_check) { + isDNSUpdateValid, err := c.FindBestHosts(bad_hosts_to_check) + if err != nil { + t.Errorf("Error while finding best hosts. error:%v", err) + } + if isDNSUpdateValid { t.Errorf("e.Find_best_hosts: returned true, expected false") } if !reflect.DeepEqual(c.Current_best_ips, expected_current_best_ips) { @@ -74,8 +86,11 @@ func TestFindBestHostsNoValidHostMinino(t *testing.T) { bad_hosts_to_check := getBadHostsToCheck(c) expected_current_best_ips := []net.IP{} - - if !c.FindBestHosts(bad_hosts_to_check) { + isDNSUpdateValid, err := c.FindBestHosts(bad_hosts_to_check) + if err != nil { + t.Errorf("Error while finding best hosts. error:%v", err) + } + if !isDNSUpdateValid { t.Errorf("e.Find_best_hosts: returned false, expected true") } if !reflect.DeepEqual(c.Current_best_ips, expected_current_best_ips) { @@ -95,8 +110,11 @@ func TestFindBestHostsNoValidHostMinimum(t *testing.T) { bad_hosts_to_check := getBadHostsToCheck(c) not_expected_current_best_ips := []net.IP{} - - if !c.FindBestHosts(bad_hosts_to_check) { + isDNSUpdateValid, err := c.FindBestHosts(bad_hosts_to_check) + if err != nil { + t.Errorf("Error while finding best hosts. error:%v", err) + } + if !isDNSUpdateValid { t.Errorf("e.Find_best_hosts: returned false, expected true") } if reflect.DeepEqual(c.Current_best_ips, not_expected_current_best_ips) { diff --git a/tests/get_list_hosts_test.go b/tests/get_list_hosts_test.go index eab3360..f364b95 100644 --- a/tests/get_list_hosts_test.go +++ b/tests/get_list_hosts_test.go @@ -1,124 +1,120 @@ package main_test import ( + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" "net" + "os" "reflect" "testing" - - "gitlab.cern.ch/lb-experts/golbd/lbcluster" - "gitlab.cern.ch/lb-experts/golbd/lbhost" ) func TestGetListHostsOne(t *testing.T) { c := getTestCluster("test01.cern.ch") - - expected := map[string]lbhost.LBHost{ - "lxplus041.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus041.cern.ch", - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "monit-kafkax-17be060b0d.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "monit-kafkax-17be060b0d.cern.ch", - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus132.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus132.cern.ch", - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus133.subdo.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus133.subdo.cern.ch", - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus130.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus130.cern.ch", - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, + host1 := lbhost.NewLBHost(c.ClusterConfig, c.Slog) + host1.SetName("lxplus041.cern.ch") + host2 := lbhost.NewLBHost(c.ClusterConfig, c.Slog) + host2.SetName("monit-kafkax-17be060b0d.cern.ch") + host3 := lbhost.NewLBHost(c.ClusterConfig, c.Slog) + host3.SetName("lxplus132.cern.ch") + host4 := lbhost.NewLBHost(c.ClusterConfig, c.Slog) + host4.SetName("lxplus041.cern.ch") + host5 := lbhost.NewLBHost(c.ClusterConfig, c.Slog) + host5.SetName("lxplus041.cern.ch") + expected := map[string]lbhost.Host{ + "lxplus041.cern.ch": host1, + "monit-kafkax-17be060b0d.cern.ch": host2, + "lxplus132.cern.ch": host3, + "lxplus133.subdo.cern.ch": host4, + "lxplus130.cern.ch": host5, } - - hosts_to_check := make(map[string]lbhost.LBHost) - c.Get_list_hosts(hosts_to_check) - if !reflect.DeepEqual(hosts_to_check, expected) { - t.Errorf("e.Get_list_hosts: got\n%v\nexpected\n%v", hosts_to_check, expected) + hostsToCheck := make(map[string]lbhost.Host) + c.GetHostList(hostsToCheck) + if len(hostsToCheck) != len(expected) { + t.Errorf("length mismatch. expected :%v, actual:%v", len(expected), len(hostsToCheck)) + } + for hostName, actualHost := range hostsToCheck { + expHost := expected[hostName] + if !reflect.DeepEqual(expHost.GetClusterConfig(), actualHost.GetClusterConfig()) { + t.Errorf("mismatch in cluster config. expected:%v,actual:%v", expHost.GetClusterConfig(), actualHost.GetClusterConfig()) + } } + os.Remove("sample.log") } func TestGetListHostsTwo(t *testing.T) { - lg := lbcluster.Log{Stdout: true, Debugflag: false} + logger, _ := logger.NewLoggerFactory("sample.log") + logger.EnableWriteToSTd() clusters := []lbcluster.LBCluster{ - {Cluster_name: "test01.cern.ch", + {ClusterConfig: model.ClusterConfig{ + Cluster_name: "test01.cern.ch", Loadbalancing_username: "loadbalancing", Loadbalancing_password: "zzz123", - Host_metric_table: map[string]lbcluster.Node{"lxplus142.cern.ch": lbcluster.Node{}, "lxplus177.cern.ch": lbcluster.Node{}}, - Parameters: lbcluster.Params{Behaviour: "mindless", Best_hosts: 2, External: true, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, + }, Host_metric_table: map[string]lbcluster.Node{"lxplus142.cern.ch": lbcluster.Node{}, "lxplus177.cern.ch": lbcluster.Node{}}, + Parameters: lbcluster.Params{Behaviour: "mindless", Best_hosts: 2, External: true, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, //Time_of_last_evaluation time.Time Current_best_ips: []net.IP{}, Previous_best_ips_dns: []net.IP{}, - Slog: &lg, + Slog: logger, Current_index: 0}, - lbcluster.LBCluster{Cluster_name: "test02.cern.ch", + lbcluster.LBCluster{ClusterConfig: model.ClusterConfig{ + Cluster_name: "test02.cern.ch", Loadbalancing_username: "loadbalancing", Loadbalancing_password: "zzz123", - Host_metric_table: map[string]lbcluster.Node{"lxplus013.cern.ch": lbcluster.Node{}, "lxplus177.cern.ch": lbcluster.Node{}, "lxplus025.cern.ch": lbcluster.Node{}}, - Parameters: lbcluster.Params{Behaviour: "mindless", Best_hosts: 10, External: false, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, + }, + Host_metric_table: map[string]lbcluster.Node{"lxplus013.cern.ch": lbcluster.Node{}, "lxplus177.cern.ch": lbcluster.Node{}, "lxplus025.cern.ch": lbcluster.Node{}}, + Parameters: lbcluster.Params{Behaviour: "mindless", Best_hosts: 10, External: false, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, //Time_of_last_evaluation time.Time Current_best_ips: []net.IP{}, Previous_best_ips_dns: []net.IP{}, - Slog: &lg, + Slog: logger, Current_index: 0}} - expected := map[string]lbhost.LBHost{ - "lxplus142.cern.ch": lbhost.LBHost{Cluster_name: "test01.cern.ch", - Host_name: "lxplus142.cern.ch", - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", - LogFile: "", - Debugflag: false, - }, - "lxplus177.cern.ch": lbhost.LBHost{Cluster_name: "test01.cern.ch,test02.cern.ch", - Host_name: "lxplus177.cern.ch", - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", - LogFile: "", - Debugflag: false, - }, - "lxplus013.cern.ch": lbhost.LBHost{Cluster_name: "test02.cern.ch", - Host_name: "lxplus013.cern.ch", - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", - LogFile: "", - Debugflag: false, - }, - "lxplus025.cern.ch": lbhost.LBHost{Cluster_name: "test02.cern.ch", - Host_name: "lxplus025.cern.ch", - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", - LogFile: "", - Debugflag: false, - }, + host1 := lbhost.NewLBHost(model.ClusterConfig{ + Cluster_name: "test01.cern.ch", + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, logger) + host1.SetName("lxplus142.cern.ch") + host2 := lbhost.NewLBHost(model.ClusterConfig{ + Cluster_name: "test01.cern.ch,test02.cern.ch", + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, logger) + host2.SetName("lxplus177.cern.ch") + host3 := lbhost.NewLBHost(model.ClusterConfig{ + Cluster_name: "test02.cern.ch", + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, logger) + host3.SetName("lxplus013.cern.ch") + host4 := lbhost.NewLBHost(model.ClusterConfig{ + Cluster_name: "test02.cern.ch", + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, logger) + host4.SetName("lxplus025.cern.ch") + expected := map[string]lbhost.Host{ + "lxplus142.cern.ch": host1, + "lxplus177.cern.ch": host2, + "lxplus013.cern.ch": host3, + "lxplus025.cern.ch": host4, } - hosts_to_check := make(map[string]lbhost.LBHost) + hostsToCheck := make(map[string]lbhost.Host) for _, c := range clusters { - c.Get_list_hosts(hosts_to_check) + c.GetHostList(hostsToCheck) } - if !reflect.DeepEqual(hosts_to_check, expected) { - t.Errorf("e.Get_list_hosts: got\n%v\nexpected\n%v", hosts_to_check, expected) + + for hostName, actualHost := range hostsToCheck { + expHost := expected[hostName] + if !reflect.DeepEqual(expHost.GetClusterConfig(), actualHost.GetClusterConfig()) { + t.Errorf("mismatch in cluster config. expected:%v,actual:%v", expHost.GetClusterConfig(), actualHost.GetClusterConfig()) + } } + os.Remove("sample.log") } diff --git a/tests/get_state_dns_test.go b/tests/get_state_dns_test.go index f691afc..c8493d3 100644 --- a/tests/get_state_dns_test.go +++ b/tests/get_state_dns_test.go @@ -1,11 +1,13 @@ package main_test import ( + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" "net" + "os" "reflect" "testing" - - "gitlab.cern.ch/lb-experts/golbd/lbcluster" ) //TestGetStateDNS tests the function GetStateDNS @@ -45,13 +47,13 @@ func TestGetStateDNS(t *testing.T) { iprecString = append(iprecString, ip.String()) } //Casting to string. The DeepEqual of IP is a bit tricky, since it can - received[c.Cluster_name] = []interface{}{iprecString, err} + received[c.ClusterConfig.Cluster_name] = []interface{}{iprecString, err} } //DeepEqual comparison between the map with expected values and the one with the outputs for _, c := range Clusters { - if !reflect.DeepEqual(received[c.Cluster_name], expected[c.Cluster_name]) { - t.Errorf("\ngot ips\n%T type and value %v\nexpected\n%T type and value %v", received[c.Cluster_name][0], received[c.Cluster_name][0], expected[c.Cluster_name][0], expected[c.Cluster_name][0]) - t.Errorf("\ngot error\n%T type and value %v\nexpected\n%T type and value %v", received[c.Cluster_name][1], received[c.Cluster_name][1], expected[c.Cluster_name][1], expected[c.Cluster_name][1]) + if !reflect.DeepEqual(received[c.ClusterConfig.Cluster_name], expected[c.ClusterConfig.Cluster_name]) { + t.Errorf("\ngot ips\n%T type and value %v\nexpected\n%T type and value %v", received[c.ClusterConfig.Cluster_name][0], received[c.ClusterConfig.Cluster_name][0], expected[c.ClusterConfig.Cluster_name][0], expected[c.ClusterConfig.Cluster_name][0]) + t.Errorf("\ngot error\n%T type and value %v\nexpected\n%T type and value %v", received[c.ClusterConfig.Cluster_name][1], received[c.ClusterConfig.Cluster_name][1], expected[c.ClusterConfig.Cluster_name][1], expected[c.ClusterConfig.Cluster_name][1]) } } } @@ -80,12 +82,14 @@ func TestRefreshDNS(t *testing.T) { for _, tc := range tests { t.Run(tc.cluster_name, func(t *testing.T) { - lg := lbcluster.Log{SyslogWriter: nil, Stdout: false, Debugflag: false} + lg, _ := logger.NewLoggerFactory("sample.log") cluster := lbcluster.LBCluster{ - Cluster_name: tc.cluster_name, + ClusterConfig: model.ClusterConfig{ + Cluster_name: tc.cluster_name, + }, Current_best_ips: tc.current_best_ips, Previous_best_ips_dns: []net.IP{}, - Slog: &lg, + Slog: lg, } cluster.RefreshDNS(dnsManager, "test-", "aW50ZXJuYWxzZWNyZXQ=", "ZXh0ZXJuYWxzZWNyZXQ=") @@ -106,4 +110,5 @@ func TestRefreshDNS(t *testing.T) { } }) } + os.Remove("sample.log") } diff --git a/tests/lbcluster_log_test.go b/tests/lbcluster_log_test.go new file mode 100644 index 0000000..f2f75ac --- /dev/null +++ b/tests/lbcluster_log_test.go @@ -0,0 +1,177 @@ +package main_test + +import ( + "fmt" + "io/ioutil" + "lb-experts/golbd/logger" + "log" + "os" + "strings" + "testing" + "time" +) + +const ( + testLogDirPath = "../tests" + testFileName = "sample" +) + +func getLogFilePath() string { + return fmt.Sprintf("%s/%s.%s", testLogDirPath, testFileName, "log") +} +func TestLBClusterLoggerForInitFailure(t *testing.T) { + _, err := logger.NewLoggerFactory("") + if err == nil { + t.Errorf("expected error not thrown") + } +} + +func TestLBClusterLoggerForInitSuccess(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + if logger == nil { + t.Fail() + t.Errorf("logger instance is nil") + } +} + +func TestLBClusterLoggerForSnapshot(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + if logger == nil { + t.Fail() + t.Errorf("logger instance is nil") + } + logger.StartSnapshot(1 * time.Minute) + if !strings.HasSuffix(logger.GetLogFilePath(), "sample.0.log") { + t.Fail() + t.Errorf("error while setting snapshot") + } +} + +func TestLBClusterLoggerForNewSnapshot(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + if logger == nil { + t.Fail() + t.Errorf("logger instance is nil") + } + curTime := time.Now() + logger.StartSnapshot(5 * time.Second) + time.Sleep(5 * time.Second) + curTime = curTime.Add(5 * time.Second) + logger.Info("sample info") + if !strings.Contains(logger.GetLogFilePath(), + fmt.Sprintf("%v_%v_%v-%v_%v_%v", curTime.Year(), curTime.Month(), curTime.Day(), curTime.Hour(), curTime.Minute(), curTime.Second())) { + t.Fail() + t.Errorf("error while setting snapshot") + } +} + +func TestLBClusterLoggerForDebugDisabled(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + logger.Debug("sample info") + if isLogPresentInFile(t, getLogFilePath(), "sample info") { + t.Fail() + t.Errorf("log file does not contain the expected debug info. Expected Info: %s", "sample info") + } +} + +func TestLBClusterLoggerForDebugEnabled(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + logger.EnableDebugMode() + logger.Debug("sample info") + if !isLogPresentInFile(t, getLogFilePath(), "sample info") { + t.Fail() + t.Errorf("log file does not contain the expected debug info. Expected Info: %s", "sample info") + } +} + +func TestLBClusterLoggerForInfo(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + logger.Info("sample info") + if !isLogPresentInFile(t, getLogFilePath(), "INFO: sample info") { + t.Fail() + t.Errorf("log file does not contain the expected debug info. Expected Info: %s", "INFO: sample info") + } +} + +func TestLBClusterLoggerForWarning(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + logger.Warning("sample info") + if !isLogPresentInFile(t, getLogFilePath(), "WARNING: sample info") { + t.Fail() + t.Errorf("log file does not contain the expected debug info. Expected Info: %s", "WARNING: sample info") + } +} + +func TestLBClusterLoggerForError(t *testing.T) { + defer deleteFile(t) + logger, err := logger.NewLoggerFactory(getLogFilePath()) + if err != nil { + t.Fail() + t.Errorf("unexpected error thrown. error: %v", err) + } + logger.Error("sample info") + if !isLogPresentInFile(t, getLogFilePath(), "ERROR: sample info") { + t.Fail() + t.Errorf("log file does not contain the expected debug info. Expected Info: %s", "ERROR: sample info") + } +} + +func deleteFile(t *testing.T) { + files, err := ioutil.ReadDir(testLogDirPath) + if err != nil { + log.Fatal(err) + } + + for _, f := range files { + if strings.HasPrefix(f.Name(), testFileName) { + err = os.Remove(f.Name()) + if err != nil { + t.Errorf("error whil deleting log file error: %v", err) + } + } + } +} + +func isLogPresentInFile(t *testing.T, filePath string, stringToCheck string) bool { + data, err := ioutil.ReadFile(filePath) + if err != nil { + t.Errorf("error while reading log file. error: %v", err) + return false + } + return strings.Contains(string(data), stringToCheck) +} diff --git a/tests/lbhost_test.go b/tests/lbhost_test.go new file mode 100644 index 0000000..572bb1a --- /dev/null +++ b/tests/lbhost_test.go @@ -0,0 +1,50 @@ +package main_test + +import ( + "github.com/reguero/go-snmplib" + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" + "net" + "os" + "testing" + "time" +) + +type mockSNMPAgent struct { +} + +func (m mockSNMPAgent) Close() error { + return nil +} + +func (m mockSNMPAgent) Discover() error { + time.Sleep(1 * time.Second) + return nil +} + +func (m mockSNMPAgent) GetV3(oid snmplib.Oid) (interface{}, error) { + return 200, nil +} + +func NewMockSNMPAgent() lbhost.DiscoveryAgent { + return &mockSNMPAgent{} +} + +func TestSNMPDiscoveryForConcurrency(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + lg.EnableWriteToSTd() + host := lbhost.LBHost{Logger: lg, SnmpAgent: NewMockSNMPAgent()} + host.HostTransports = []lbhost.LBHostTransportResult{ + {IP: net.ParseIP("1.1.1.1"), Transport: "udp"}, + {IP: net.ParseIP("1.1.1.2"), Transport: "udp"}, + {IP: net.ParseIP("1.1.1.3"), Transport: "udp"}, + } + startTime := time.Now() + host.SNMPDiscovery() + endTime := time.Now() + if endTime.Sub(startTime) > 2*time.Second { + t.Fail() + t.Errorf("execution took more time than expected. expectedTime: %v, actualTime:%v", 1, endTime.Sub(startTime)) + } + os.Remove("sample.log") +} diff --git a/tests/loadClusters_test.go b/tests/loadClusters_test.go index 380af1c..57fab5c 100644 --- a/tests/loadClusters_test.go +++ b/tests/loadClusters_test.go @@ -1,20 +1,26 @@ package main_test import ( + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbconfig" + "lb-experts/golbd/lbhost" + "lb-experts/golbd/logger" + "lb-experts/golbd/model" "net" + "os" "reflect" "testing" - - "gitlab.cern.ch/lb-experts/golbd/lbcluster" - "gitlab.cern.ch/lb-experts/golbd/lbconfig" - "gitlab.cern.ch/lb-experts/golbd/lbhost" ) func getTestCluster(name string) lbcluster.LBCluster { - lg := lbcluster.Log{SyslogWriter: nil, Stdout: true, Debugflag: false} - return lbcluster.LBCluster{Cluster_name: name, - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", + lg, _ := logger.NewLoggerFactory("sample.log") + + return lbcluster.LBCluster{ + ClusterConfig: model.ClusterConfig{ + Cluster_name: name, + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, Host_metric_table: map[string]lbcluster.Node{ "lxplus132.cern.ch": lbcluster.Node{Load: 100000, IPs: []net.IP{}}, "lxplus041.cern.ch": lbcluster.Node{Load: 100000, IPs: []net.IP{}}, @@ -25,15 +31,19 @@ func getTestCluster(name string) lbcluster.LBCluster { //Time_of_last_evaluation time.Time Current_best_ips: []net.IP{}, Previous_best_ips_dns: []net.IP{}, - Slog: &lg, + Slog: lg, Current_index: 0} } func getSecondTestCluster() lbcluster.LBCluster { - lg := lbcluster.Log{SyslogWriter: nil, Stdout: true, Debugflag: false} - return lbcluster.LBCluster{Cluster_name: "test02.test.cern.ch", - Loadbalancing_username: "loadbalancing", - Loadbalancing_password: "zzz123", + lg, _ := logger.NewLoggerFactory("sample.log") + + return lbcluster.LBCluster{ + ClusterConfig: model.ClusterConfig{ + Cluster_name: "test02.test.cern.ch", + Loadbalancing_username: "loadbalancing", + Loadbalancing_password: "zzz123", + }, Host_metric_table: map[string]lbcluster.Node{ "lxplus013.cern.ch": lbcluster.Node{Load: 100000, IPs: []net.IP{}}, "lxplus038.cern.ch": lbcluster.Node{Load: 100000, IPs: []net.IP{}}, @@ -43,152 +53,140 @@ func getSecondTestCluster() lbcluster.LBCluster { //Time_of_last_evaluation time.Time Current_best_ips: []net.IP{}, Previous_best_ips_dns: []net.IP{}, - Slog: &lg, + Slog: lg, Current_index: 0} } -func getHostsToCheck(c lbcluster.LBCluster) map[string]lbhost.LBHost { - hostsToCheck := map[string]lbhost.LBHost{ - "lxplus132.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus132.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{ - lbhost.LBHostTransportResult{Transport: "udp6", Response_int: 2, Response_string: "", IP: net.ParseIP("2001:1458:d00:2c::100:a6"), Response_error: ""}, - lbhost.LBHostTransportResult{Transport: "udp", Response_int: 2, Response_string: "", IP: net.ParseIP("188.184.108.98"), Response_error: ""}, - }, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus041.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus041.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{ - lbhost.LBHostTransportResult{Transport: "udp6", Response_int: 3, Response_string: "", IP: net.ParseIP("2001:1458:d00:32::100:51"), Response_error: ""}, - lbhost.LBHostTransportResult{Transport: "udp", Response_int: 3, Response_string: "", IP: net.ParseIP("188.184.116.81"), Response_error: ""}, - }, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus130.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus130.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: 27, Response_string: "", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus133.subdo.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus130.subdo.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: 27, Response_string: "", IP: net.ParseIP("188.184.108.101"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "monit-kafkax-17be060b0d.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "monit-kafkax-17be060b0d.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: 100000, Response_string: "monit-kafkax.cern.ch=816,monit-kafka.cern.ch=816,test01.cern.ch=816", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, +func getHostsToCheck(c lbcluster.LBCluster) map[string]lbhost.Host { + lg, _ := logger.NewLoggerFactory("sample.log") + host1 := lbhost.NewLBHost(c.ClusterConfig, lg) + host1.SetName("lxplus132.cern.ch") + host1.SetTransportPayload([]lbhost.LBHostTransportResult{ + lbhost.LBHostTransportResult{Transport: "udp6", Response_int: 2, Response_string: "", IP: net.ParseIP("2001:1458:d00:2c::100:a6"), Response_error: ""}, + lbhost.LBHostTransportResult{Transport: "udp", Response_int: 2, Response_string: "", IP: net.ParseIP("188.184.108.98"), Response_error: ""}, + }) + host2 := lbhost.NewLBHost(c.ClusterConfig, lg) + host2.SetName("lxplus041.cern.ch") + host2.SetTransportPayload([]lbhost.LBHostTransportResult{ + lbhost.LBHostTransportResult{Transport: "udp6", Response_int: 3, Response_string: "", IP: net.ParseIP("2001:1458:d00:32::100:51"), Response_error: ""}, + lbhost.LBHostTransportResult{Transport: "udp", Response_int: 3, Response_string: "", IP: net.ParseIP("188.184.116.81"), Response_error: ""}, + }) + host3 := lbhost.NewLBHost(c.ClusterConfig, lg) + host3.SetName("lxplus130.cern.ch") + host3.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: 27, Response_string: "", IP: net.ParseIP("188.184.108.100"), Response_error: "", + }}) + host4 := lbhost.NewLBHost(c.ClusterConfig, lg) + host4.SetName("lxplus133.subdo.cern.ch") + host4.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: 27, Response_string: "", IP: net.ParseIP("188.184.108.101"), Response_error: "", + }}) + host5 := lbhost.NewLBHost(c.ClusterConfig, lg) + host5.SetName("monit-kafkax-17be060b0d.cern.ch") + host5.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: 100000, Response_string: "monit-kafkax.cern.ch=816,monit-kafka.cern.ch=816,test01.cern.ch=816", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, + ) + hostsToCheck := map[string]lbhost.Host{ + "lxplus132.cern.ch": host1, + "lxplus041.cern.ch": host2, + "lxplus130.cern.ch": host3, + "lxplus133.subdo.cern.ch": host4, + "monit-kafkax-17be060b0d.cern.ch": host5, } return hostsToCheck } -func getBadHostsToCheck(c lbcluster.LBCluster) map[string]lbhost.LBHost { - badHostsToCheck := map[string]lbhost.LBHost{ - "lxplus132.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus132.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{ - lbhost.LBHostTransportResult{Transport: "udp6", Response_int: -2, Response_string: "", IP: net.ParseIP("2001:1458:d00:2c::100:a6"), Response_error: ""}, - lbhost.LBHostTransportResult{Transport: "udp", Response_int: -2, Response_string: "", IP: net.ParseIP("188.184.108.98"), Response_error: ""}, - }, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus041.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus041.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{ - lbhost.LBHostTransportResult{Transport: "udp6", Response_int: -3, Response_string: "", IP: net.ParseIP("2001:1458:d00:32::100:51"), Response_error: ""}, - lbhost.LBHostTransportResult{Transport: "udp", Response_int: -3, Response_string: "", IP: net.ParseIP("188.184.116.81"), Response_error: ""}, - }, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus130.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus130.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: -27, Response_string: "", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "lxplus133.subdo.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "lxplus133.subdo.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: -15, Response_string: "", IP: net.ParseIP("188.184.108.101"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, - "monit-kafkax-17be060b0d.cern.ch": lbhost.LBHost{Cluster_name: c.Cluster_name, - Host_name: "monit-kafkax-17be060b0d.cern.ch", - Host_transports: []lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{Transport: "udp", Response_int: 100000, Response_string: "monit-kafkax.cern.ch=816,monit-kafka.cern.ch=816,test01.cern.ch=816", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, - Loadbalancing_username: c.Loadbalancing_username, - Loadbalancing_password: c.Loadbalancing_password, - LogFile: c.Slog.TofilePath, - Debugflag: c.Slog.Debugflag, - }, +func getBadHostsToCheck(c lbcluster.LBCluster) map[string]lbhost.Host { + lg, _ := logger.NewLoggerFactory("sample.log") + host1 := lbhost.NewLBHost(c.ClusterConfig, lg) + host1.SetName("lxplus132.cern.ch") + host1.SetTransportPayload([]lbhost.LBHostTransportResult{ + lbhost.LBHostTransportResult{Transport: "udp6", Response_int: -2, Response_string: "", IP: net.ParseIP("2001:1458:d00:2c::100:a6"), Response_error: ""}, + lbhost.LBHostTransportResult{Transport: "udp", Response_int: -2, Response_string: "", IP: net.ParseIP("188.184.108.98"), Response_error: ""}, + }) + host2 := lbhost.NewLBHost(c.ClusterConfig, lg) + host2.SetName("lxplus041.cern.ch") + host2.SetTransportPayload([]lbhost.LBHostTransportResult{ + lbhost.LBHostTransportResult{Transport: "udp6", Response_int: -3, Response_string: "", IP: net.ParseIP("2001:1458:d00:32::100:51"), Response_error: ""}, + lbhost.LBHostTransportResult{Transport: "udp", Response_int: -3, Response_string: "", IP: net.ParseIP("188.184.116.81"), Response_error: ""}, + }) + host3 := lbhost.NewLBHost(c.ClusterConfig, lg) + host3.SetName("lxplus130.cern.ch") + host3.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: -27, Response_string: "", IP: net.ParseIP("188.184.108.100"), Response_error: "", + }}) + host4 := lbhost.NewLBHost(c.ClusterConfig, lg) + host4.SetName("lxplus133.subdo.cern.ch") + host4.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: -15, Response_string: "", IP: net.ParseIP("188.184.108.101"), Response_error: "", + }}) + host5 := lbhost.NewLBHost(c.ClusterConfig, lg) + host5.SetName("monit-kafkax-17be060b0d.cern.ch") + host5.SetTransportPayload([]lbhost.LBHostTransportResult{lbhost.LBHostTransportResult{ + Transport: "udp", Response_int: 100000, Response_string: "monit-kafkax.cern.ch=816,monit-kafka.cern.ch=816,test01.cern.ch=816", IP: net.ParseIP("188.184.108.100"), Response_error: ""}}, + ) + badHostsToCheck := map[string]lbhost.Host{ + "lxplus132.cern.ch": host1, + "lxplus041.cern.ch": host2, + "lxplus130.cern.ch": host3, + "lxplus133.subdo.cern.ch": host4, + "monit-kafkax-17be060b0d.cern.ch": host5, } return badHostsToCheck } -func getHost(hostname string, responseInt int, responseString string) lbhost.LBHost { - - return lbhost.LBHost{Cluster_name: "test01.cern.ch", - Host_name: hostname, - Host_transports: []lbhost.LBHostTransportResult{ - lbhost.LBHostTransportResult{Transport: "udp", Response_int: responseInt, Response_string: responseString, IP: net.ParseIP("188.184.108.98"), Response_error: ""}}, +func getHost(hostname string, responseInt int, responseString string) lbhost.Host { + lg, _ := logger.NewLoggerFactory("sample.log") + clusterConfig := model.ClusterConfig{ + Cluster_name: "test01.cern.ch", Loadbalancing_username: "loadbalancing", Loadbalancing_password: "XXXX", - LogFile: "", - Debugflag: false, } + host1 := lbhost.NewLBHost(clusterConfig, lg) + host1.SetName(hostname) + host1.SetTransportPayload([]lbhost.LBHostTransportResult{ + lbhost.LBHostTransportResult{Transport: "udp", Response_int: responseInt, Response_string: responseString, IP: net.ParseIP("188.184.108.98"), Response_error: ""}}, + ) + return host1 } func TestLoadClusters(t *testing.T) { - lg := lbcluster.Log{SyslogWriter: nil, Stdout: true, Debugflag: false} + lg, _ := logger.NewLoggerFactory("sample.log") + lg.EnableWriteToSTd() - config := lbconfig.Config{Master: "lbdxyz.cern.ch", - HeartbeatFile: "heartbeat", - HeartbeatPath: "/work/go/src/github.com/cernops/golbd", - //HeartbeatMu: sync.Mutex{0, 0}, - TsigKeyPrefix: "abcd-", - TsigInternalKey: "xxx123==", - TsigExternalKey: "yyy123==", - SnmpPassword: "zzz123", - DNSManager: "111.111.0.111:53", - Clusters: map[string][]string{"test01.cern.ch": {"lxplus132.cern.ch", "lxplus041.cern.ch", "lxplus130.cern.ch", "lxplus133.subdo.cern.ch", "monit-kafkax-17be060b0d.cern.ch"}, "test02.test.cern.ch": {"lxplus013.cern.ch", "lxplus038.cern.ch", "lxplus039.test.cern.ch", "lxplus025.cern.ch"}}, - Parameters: map[string]lbcluster.Params{"test01.cern.ch": lbcluster.Params{Behaviour: "mindless", Best_hosts: 2, - External: true, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, - "test02.test.cern.ch": lbcluster.Params{Behaviour: "mindless", Best_hosts: 10, External: false, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}}} - expected := []lbcluster.LBCluster{getTestCluster("test01.cern.ch"), - getSecondTestCluster()} + config := lbconfig.NewLoadBalancerConfig("", lg) + config.SetMasterHost("lbdxyz.cern.ch") + config.SetHeartBeatFileName("heartbeat") + config.SetHeartBeatDirPath("/work/go/src/github.com/cernops/golbd") + config.SetTSIGKeyPrefix("abcd-") + config.SetTSIGInternalKey("xxx123==") + config.SetTSIGExternalKey("yyy123==") + config.SetDNSManager("111.111.0.111:53") + config.SetSNMPPassword("zzz123") + config.SetClusters(map[string][]string{ + "test01.cern.ch": {"lxplus132.cern.ch", "lxplus041.cern.ch", "lxplus130.cern.ch", "lxplus133.subdo.cern.ch", "monit-kafkax-17be060b0d.cern.ch"}, + "test02.test.cern.ch": {"lxplus013.cern.ch", "lxplus038.cern.ch", "lxplus039.test.cern.ch", "lxplus025.cern.ch"}, + }) + config.SetParameters(map[string]lbcluster.Params{ + "test01.cern.ch": lbcluster.Params{Behaviour: "mindless", Best_hosts: 2, External: true, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, + "test02.test.cern.ch": lbcluster.Params{Behaviour: "mindless", Best_hosts: 10, External: false, Metric: "cmsfrontier", Polling_interval: 6, Statistics: "long"}, + }) + testCluster1 := getTestCluster("test01.cern.ch") + testCluster1.Slog = lg + testCluster2 := getSecondTestCluster() + testCluster2.Slog = lg + expected := []lbcluster.LBCluster{testCluster1, testCluster2} - lbclusters, _ := lbconfig.LoadClusters(&config, &lg) + lbclusters, err := config.LoadClusters() + if err != nil { + t.Errorf("error while loading clusters. error: %v", err) + return + } // reflect.DeepEqual(lbclusters, expected) occassionally fails as the array order is not always the same // so comparing element par element i := 0 for _, e := range expected { for _, c := range lbclusters { - if c.Cluster_name == e.Cluster_name { + if c.ClusterConfig.Cluster_name == e.ClusterConfig.Cluster_name { if !reflect.DeepEqual(c, e) { t.Errorf("loadClusters: got\n%v\nexpected\n%v", lbclusters, expected) } else { @@ -202,4 +200,9 @@ func TestLoadClusters(t *testing.T) { t.Errorf("loadClusters: wrong number of clusters, got\n%v\nexpected\n%v (and %v", len(lbclusters), len(expected), i) } + err = os.Remove("sample.log") + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) + } } diff --git a/tests/loadConfig_test.go b/tests/loadConfig_test.go index f8b2c0d..5a358e0 100644 --- a/tests/loadConfig_test.go +++ b/tests/loadConfig_test.go @@ -3,60 +3,120 @@ package main import ( "os" "reflect" + "sync" "testing" + "time" - "gitlab.cern.ch/lb-experts/golbd/lbcluster" - "gitlab.cern.ch/lb-experts/golbd/lbconfig" + "lb-experts/golbd/lbcluster" + "lb-experts/golbd/lbconfig" + "lb-experts/golbd/logger" ) func TestLoadConfig(t *testing.T) { - lg := lbcluster.Log{Stdout: true, Debugflag: false} testFiles := []string{"testloadconfig.yaml", "testloadconfig"} - - //open files + lg, _ := logger.NewLoggerFactory("sample.log") + lg.EnableWriteToSTd() for _, testFile := range testFiles { - loadconfig, err := os.Open(testFile) + configFromFile := lbconfig.NewLoadBalancerConfig(testFile, lg) + _, err := configFromFile.Load() if err != nil { - panic(err) + t.Fail() + t.Errorf("loadConfig Error: %v", err.Error()) + } + expConfig := lbconfig.NewLoadBalancerConfig(testFile, lg) + expConfig.SetMasterHost("lbdxyz.cern.ch") + expConfig.SetHeartBeatFileName("heartbeat") + expConfig.SetHeartBeatDirPath("/work/go/src/github.com/cernops/golbd") + expConfig.SetTSIGKeyPrefix("abcd-") + expConfig.SetTSIGInternalKey("xxx123==") + expConfig.SetTSIGExternalKey("yyy123==") + expConfig.SetDNSManager("137.138.28.176:53") + expConfig.SetSNMPPassword("zzz123") + expConfig.SetClusters(map[string][]string{ + "aiermis.cern.ch": {"ermis19.cern.ch", "ermis20.cern.ch"}, + "uermis.cern.ch": {"ermis21.cern.ch", "ermis22.cern.ch"}, + "permis.cern.ch": {"ermis21.sub.cern.ch", "ermis22.test.cern.ch", "ermis42.cern.ch"}, + "ermis.test.cern.ch": {"ermis23.cern.ch", "ermis24.cern.ch"}, + "ermis2.test.cern.ch": {"ermis23.toto.cern.ch", "ermis24.cern.ch", "ermis25.sub.cern.ch"}, + }) + expConfig.SetParameters(map[string]lbcluster.Params{ + "aiermis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 60}, + "uermis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, + "permis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, + "ermis.test.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, + "ermis2.test.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, + }) + + if !reflect.DeepEqual(configFromFile, expConfig) { + t.Errorf("loadConfig: got\n %v expected\n %v", configFromFile, expConfig) } - defer loadconfig.Close() - - // The expected output - expected := - lbconfig.Config{ - Master: "lbdxyz.cern.ch", - HeartbeatFile: "heartbeat", - HeartbeatPath: "/work/go/src/github.com/cernops/golbd", - //HeartbeatMu: sync.Mutex{0, 0}, - TsigKeyPrefix: "abcd-", - TsigInternalKey: "xxx123==", - TsigExternalKey: "yyy123==", - SnmpPassword: "zzz123", - DNSManager: "137.138.28.176:53", - ConfigFile: testFile, - Clusters: map[string][]string{ - "aiermis.cern.ch": {"ermis19.cern.ch", "ermis20.cern.ch"}, - "uermis.cern.ch": {"ermis21.cern.ch", "ermis22.cern.ch"}, - "permis.cern.ch": {"ermis21.sub.cern.ch", "ermis22.test.cern.ch", "ermis42.cern.ch"}, - "ermis.test.cern.ch": {"ermis23.cern.ch", "ermis24.cern.ch"}, - "ermis2.test.cern.ch": {"ermis23.toto.cern.ch", "ermis24.cern.ch", "ermis25.sub.cern.ch"}}, - Parameters: map[string]lbcluster.Params{ - "aiermis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 60}, - "uermis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, - "permis.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, - "ermis.test.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}, - "ermis2.test.cern.ch": {Behaviour: "mindless", Best_hosts: 1, External: false, Metric: "cmsfrontier", Polling_interval: 300, Statistics: "long", Ttl: 222}}} - - //retrieving the actual output - configExisting, _, e := lbconfig.LoadConfig(loadconfig.Name(), &lg) - - if e != nil { - t.Errorf("loadConfig Error: %v", e.Error()) - } else { - if !reflect.DeepEqual(configExisting, &expected) { - t.Errorf("loadConfig: got\n %+v \nexpected\n %+v", configExisting, &expected) + + } + os.Remove("sample.log") + +} +func TestWatchConfigFileChanges(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + lg.EnableWriteToSTd() + var wg sync.WaitGroup + var controlChan = make(chan bool) + var changeCounter int + sampleConfigFileName := "sampleConfig" + dataSet := []string{ + "data 1", + "data 12", + "data 123", + } + + err := createTestConfigFile(sampleConfigFileName) + if err != nil { + t.Fail() + t.Errorf("error while creating test config file. name: %s", sampleConfigFileName) + } + go func() { + defer close(controlChan) + for _, dataToWrite := range dataSet { + time.Sleep(1 * time.Second) + err = writeDataToFile(sampleConfigFileName, dataToWrite) + if err != nil { + t.Fail() + t.Errorf("error while writting to test config file. filename: %s, data:%s", sampleConfigFileName, dataToWrite) } } + }() + config := lbconfig.NewLoadBalancerConfig(sampleConfigFileName, lg) + fileChangeSignal := config.WatchFileChange(controlChan, wg) + for fileChangeData := range fileChangeSignal { + changeCounter += 1 + t.Log("file change signal", fileChangeData) } + if changeCounter == 0 { + t.Fail() + t.Error("file changes not observed") + } + deleteFile("sample.log") + deleteFile(sampleConfigFileName) +} + +func createTestConfigFile(fileName string) error { + _, err := os.Create(fileName) + if err != nil { + return err + } + return nil +} + +func writeDataToFile(fileName string, data string) error { + fp, err := os.Create(fileName) + if err != nil { + return err + } + _, err = fp.WriteString(data) + fp.Close() + return err +} + +func deleteFile(fileName string) { + os.Remove(fileName) } diff --git a/tests/metric_test.go b/tests/metric_test.go new file mode 100644 index 0000000..4752770 --- /dev/null +++ b/tests/metric_test.go @@ -0,0 +1,59 @@ +package main_test + +import ( + "lb-experts/golbd/metric" + "os" + "strings" + "testing" + "time" +) + +func TestMetricReadWriteRecord(t *testing.T) { + hostName, _ := os.Hostname() + logic := metric.NewLogic("", hostName) + curTime := time.Now() + property := metric.Property{ + RoundTripStartTime: curTime, + RoundTripEndTime: curTime.Add(5 * time.Second), + RoundTripDuration: 5 * time.Second, + } + err := logic.WriteRecord(property) + if err != nil { + t.Fail() + t.Errorf("error while recoding metric. error:%v", err) + return + } + hostMetric, err := logic.ReadHostMetric() + if err != nil { + t.Fail() + t.Errorf("error while reading metric.error:%v", err) + return + } + + if hostMetric.PropertyList == nil || len(hostMetric.PropertyList) == 0 { + t.Fail() + t.Errorf("property list is empty. expected length :%d", 1) + return + } + expectedStarttime := property.RoundTripStartTime.Format(time.RFC3339) + expectedEndtime := property.RoundTripEndTime.Format(time.RFC3339) + actualStartTime := hostMetric.PropertyList[0].RoundTripStartTime.Format(time.RFC3339) + actualEndTime := hostMetric.PropertyList[0].RoundTripEndTime.Format(time.RFC3339) + if !strings.EqualFold(expectedStarttime, actualStartTime) { + t.Fail() + t.Errorf("start time value mismatch expected: %v, actual:%v", expectedStarttime, actualStartTime) + } + if !strings.EqualFold(expectedEndtime, actualEndTime) { + t.Fail() + t.Errorf("end time value mismatch expected: %v, actual:%v", expectedEndtime, actualEndTime) + } + if hostMetric.PropertyList[0].RoundTripDuration != property.RoundTripDuration { + t.Fail() + t.Errorf("duration value mismatch expected: %v, actual:%v", property.RoundTripDuration, hostMetric.PropertyList[0].RoundTripDuration) + } + err = os.Remove(logic.GetFilePath()) + if err != nil { + t.Fail() + t.Errorf("error deleting file.error %v", err) + } +} diff --git a/tests/retry_module_test.go b/tests/retry_module_test.go new file mode 100644 index 0000000..e2b970a --- /dev/null +++ b/tests/retry_module_test.go @@ -0,0 +1,108 @@ +package main_test + +import ( + "fmt" + "lb-experts/golbd/logger" + "os" + "testing" + "time" + + "lb-experts/golbd/lbcluster" +) + +func TestRetryWithNoErrorsShouldExitAfterFirstAttempt(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + currentTime := time.Now() + retryModule := lbcluster.NewRetryModule(10*time.Second, lg) + err := retryModule.Execute(func() error { + return nil + }) + if err != nil { + t.Fail() + t.Errorf("error should be nil") + } + if time.Now().Sub(currentTime) > 1*time.Second { + t.Fail() + t.Errorf("should quit after first try") + } + os.Remove("sample.log") +} + +func TestRetryWithErrorShouldQuitAfterMultipleAttempts(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + currentTime := time.Now() + counter := 0 + retryModule := lbcluster.NewRetryModule(1*time.Second, lg) + err := retryModule.Execute(func() error { + if counter == 4 { + return nil + } + counter += 1 + return fmt.Errorf("sample error") + }) + if err != nil { + t.Fail() + t.Errorf("error should be nil") + } + if time.Now().Sub(currentTime) > 12*time.Second { + t.Fail() + t.Errorf("should quit after expected: %v, actual:%v", "11 sec", time.Now().Sub(currentTime)) + } + os.Remove("sample.log") +} + +func TestRetryWithErrorShouldQuitAfterMaxCount(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + currentTime := time.Now() + counter := 0 + retryModule := lbcluster.NewRetryModule(1*time.Second, lg) + err := retryModule.SetMaxCount(3) + if err != nil { + t.Fail() + t.Errorf("error should be nil") + } + err = retryModule.Execute(func() error { + if counter == 4 { + return nil + } + counter += 1 + return fmt.Errorf("sample error") + }) + if err == nil { + t.Fail() + t.Errorf("error should be nil") + } + if time.Now().Sub(currentTime) > 4*time.Second { + t.Fail() + t.Errorf("should quit after expected: %v, actual:%v", "3 sec", time.Now().Sub(currentTime)) + } + os.Remove("sample.log") +} + +func TestRetryWithErrorShouldQuitAfterMaxDuration(t *testing.T) { + lg, _ := logger.NewLoggerFactory("sample.log") + currentTime := time.Now() + counter := 0 + retryModule := lbcluster.NewRetryModule(1*time.Second, lg) + err := retryModule.SetMaxDuration(4 * time.Second) + if err != nil { + t.Fail() + t.Errorf("error should be nil") + } + err = retryModule.Execute(func() error { + if counter == 4 { + return nil + } + counter += 1 + return fmt.Errorf("sample error") + }) + if err == nil { + t.Fail() + t.Errorf("error should be nil") + } + if time.Now().Sub(currentTime) > 5*time.Second { + t.Fail() + t.Errorf("should quit after expected: %v, actual:%v", "3 sec", time.Now().Sub(currentTime)) + } + os.Remove("sample.log") +}