diff --git a/src/cli/cli.go b/src/cli/cli.go index 8f7fca7e..1bf5b58e 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -37,6 +37,9 @@ type InputHandler interface { type OutputHandler interface { WriteResults(results <-chan string, wg *sync.WaitGroup) error } +type StatusHandler interface { + LogPeriodicUpdates(statusChan <-chan zdns.Status, wg *sync.WaitGroup) error +} // GeneralOptions core options for all ZDNS modules // Order here is the order they'll be printed to the user, so preserve alphabetical order @@ -96,9 +99,11 @@ type InputOutputOptions struct { MetadataFilePath string `long:"metadata-file" description:"where should JSON metadata be saved, defaults to no metadata output. Use '-' for stderr."` MetadataFormat bool `long:"metadata-passthrough" description:"if input records have the form 'name,METADATA', METADATA will be propagated to the output"` OutputFilePath string `short:"o" long:"output-file" default:"-" description:"where should JSON output be saved, defaults to stdout"` + QuietStatusUpdates bool `short:"q" long:"quiet" description:"do not print status updates"` NameOverride string `long:"override-name" description:"name overrides all passed in names. Commonly used with --name-server-mode."` NamePrefix string `long:"prefix" description:"name to be prepended to what's passed in (e.g., www.)"` ResultVerbosity string `long:"result-verbosity" default:"normal" description:"Sets verbosity of each output record. Options: short, normal, long, trace"` + StatusUpdatesFilePath string `short:"u" long:"status-updates-file" default:"-" description:"file to write scan progress to, defaults to stderr"` Verbosity int `long:"verbosity" default:"3" description:"log verbosity: 1 (lowest)--5 (highest)"` } @@ -116,6 +121,7 @@ type CLIConf struct { ClientSubnet *dns.EDNS0_SUBNET InputHandler InputHandler OutputHandler OutputHandler + StatusHandler StatusHandler CLIModule string // the module name as passed in by the user ActiveModuleNames []string // names of modules that are active in this invocation of zdns. Mostly used with MULTIPLE ActiveModules map[string]LookupModule // map of module names to modules diff --git a/src/cli/iohandlers/status_handler.go b/src/cli/iohandlers/status_handler.go new file mode 100644 index 00000000..81719ffe --- /dev/null +++ b/src/cli/iohandlers/status_handler.go @@ -0,0 +1,154 @@ +/* + * ZDNS Copyright 2024 Regents of the University of Michigan + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package iohandlers + +import ( + "fmt" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + + "github.com/zmap/zdns/src/internal/util" + "github.com/zmap/zdns/src/zdns" +) + +type StatusHandler struct { + filePath string +} + +type scanStats struct { + scanStartTime time.Time + domainsScanned int + domainsSuccess int // number of domains that returned either NXDOMAIN or NOERROR + statusOccurance map[zdns.Status]int +} + +func NewStatusHandler(filePath string) *StatusHandler { + return &StatusHandler{ + filePath: filePath, + } +} + +// LogPeriodicUpdates prints a per-second update to the user scan progress and per-status statistics +func (h *StatusHandler) LogPeriodicUpdates(statusChan <-chan zdns.Status, wg *sync.WaitGroup) error { + defer wg.Done() + // open file for writing + var f *os.File + if h.filePath == "" || h.filePath == "-" { + f = os.Stderr + } else { + // open file for writing + var err error + f, err = os.OpenFile(h.filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, util.DefaultFilePermissions) + if err != nil { + return errors.Wrap(err, "unable to open status file") + } + defer func(f *os.File) { + if err := f.Close(); err != nil { + log.Errorf("unable to close status file: %v", err) + } + }(f) + } + if err := h.statusLoop(statusChan, f); err != nil { + return errors.Wrap(err, "error encountered in status loop") + } + return nil +} + +// statusLoop will print a per-second summary of the scan progress and per-status statistics +func (h *StatusHandler) statusLoop(statusChan <-chan zdns.Status, statusFile *os.File) error { + // initialize stats + stats := scanStats{ + statusOccurance: make(map[zdns.Status]int), + scanStartTime: time.Now(), + } + ticker := time.NewTicker(time.Second) +statusLoop: + for { + select { + case <-ticker.C: + // print per-second summary + timeSinceStart := time.Since(stats.scanStartTime) + s := fmt.Sprintf("%02dh:%02dm:%02ds; %d names scanned; %.02f names/sec; %.01f%% success rate; %s\n", + int(timeSinceStart.Hours()), + int(timeSinceStart.Minutes())%60, + int(timeSinceStart.Seconds())%60, + stats.domainsScanned, + float64(stats.domainsScanned)/timeSinceStart.Seconds(), + float64(stats.domainsSuccess*100)/float64(stats.domainsScanned), + getStatusOccuranceString(stats.statusOccurance)) + if _, err := statusFile.WriteString(s); err != nil { + return errors.Wrap(err, "unable to write periodic status update") + } + case status, ok := <-statusChan: + if !ok { + // status chan closed, exiting + break statusLoop + } + stats.domainsScanned += 1 + if status == zdns.StatusNoError || status == zdns.StatusNXDomain { + stats.domainsSuccess += 1 + } + if _, ok = stats.statusOccurance[status]; !ok { + // initialize status if not seen before + stats.statusOccurance[status] = 0 + } + stats.statusOccurance[status] += 1 + } + } + timeSinceStart := time.Since(stats.scanStartTime) + s := fmt.Sprintf("%02dh:%02dm:%02ds; Scan Complete; %d names scanned; %.02f names/sec; %.01f%% success rate; %s\n", + int(timeSinceStart.Hours()), + int(timeSinceStart.Minutes())%60, + int(timeSinceStart.Seconds())%60, + stats.domainsScanned, + float64(stats.domainsScanned)/time.Since(stats.scanStartTime).Seconds(), + float64(stats.domainsSuccess*100)/float64(stats.domainsScanned), + getStatusOccuranceString(stats.statusOccurance)) + if _, err := statusFile.WriteString(s); err != nil { + return errors.Wrap(err, "unable to write final status update") + } + return nil +} + +func getStatusOccuranceString(statusOccurances map[zdns.Status]int) string { + type statusAndOccurance struct { + status zdns.Status + occurance int + } + statusesAndOccurances := make([]statusAndOccurance, 0, len(statusOccurances)) + for status, occurance := range statusOccurances { + statusesAndOccurances = append(statusesAndOccurances, statusAndOccurance{ + status: status, + occurance: occurance, + }) + } + // sort by occurance + sort.Slice(statusesAndOccurances, func(i, j int) bool { + return statusesAndOccurances[i].occurance > statusesAndOccurances[j].occurance + }) + returnStr := "" + for _, statusOccurance := range statusesAndOccurances { + returnStr += fmt.Sprintf("%s: %d, ", statusOccurance.status, statusOccurance.occurance) + } + // remove trailing comma + returnStr = strings.TrimSuffix(returnStr, ", ") + return returnStr +} diff --git a/src/cli/worker_manager.go b/src/cli/worker_manager.go index cd20cac1..5bb0c54c 100644 --- a/src/cli/worker_manager.go +++ b/src/cli/worker_manager.go @@ -162,6 +162,9 @@ func populateCLIConfig(gc *CLIConf) *CLIConf { if gc.OutputHandler == nil { gc.OutputHandler = iohandlers.NewFileOutputHandler(gc.OutputFilePath) } + if gc.StatusHandler == nil { + gc.StatusHandler = iohandlers.NewStatusHandler(gc.StatusUpdatesFilePath) + } return gc } @@ -490,6 +493,7 @@ func Run(gc CLIConf) { inChan := make(chan string) outChan := make(chan string) metaChan := make(chan routineMetadata, gc.Threads) + statusChan := make(chan zdns.Status) var routineWG sync.WaitGroup inHandler := gc.InputHandler @@ -502,20 +506,33 @@ func Run(gc CLIConf) { log.Fatal("Output handler is nil") } + statusHandler := gc.StatusHandler + if statusHandler == nil { + log.Fatal("Status handler is nil") + } + // Use handlers to populate the input and output/results channel go func() { - inErr := inHandler.FeedChannel(inChan, &routineWG) - if inErr != nil { + if inErr := inHandler.FeedChannel(inChan, &routineWG); inErr != nil { log.Fatal(fmt.Sprintf("could not feed input channel: %v", inErr)) } }() + go func() { - outErr := outHandler.WriteResults(outChan, &routineWG) - if outErr != nil { + if outErr := outHandler.WriteResults(outChan, &routineWG); outErr != nil { log.Fatal(fmt.Sprintf("could not write output results from output channel: %v", outErr)) } }() - routineWG.Add(2) + routineWG.Add(2) // input and output handlers + + if !gc.QuietStatusUpdates { + go func() { + if statusErr := statusHandler.LogPeriodicUpdates(statusChan, &routineWG); statusErr != nil { + log.Fatal(fmt.Sprintf("could not log periodic status updates: %v", statusErr)) + } + }() + routineWG.Add(1) // status handler + } // create pool of worker goroutines var lookupWG sync.WaitGroup @@ -525,7 +542,7 @@ func Run(gc CLIConf) { for i := 0; i < gc.Threads; i++ { i := i go func(threadID int) { - initWorkerErr := doLookupWorker(&gc, resolverConfig, inChan, outChan, metaChan, &lookupWG) + initWorkerErr := doLookupWorker(&gc, resolverConfig, inChan, outChan, metaChan, statusChan, &lookupWG) if initWorkerErr != nil { log.Fatalf("could not start lookup worker #%d: %v", i, initWorkerErr) } @@ -534,6 +551,7 @@ func Run(gc CLIConf) { lookupWG.Wait() close(outChan) close(metaChan) + close(statusChan) routineWG.Wait() if gc.MetadataFilePath != "" { // we're done processing data. aggregate all the data from individual routines @@ -580,7 +598,7 @@ func Run(gc CLIConf) { } // doLookupWorker is a single worker thread that processes lookups from the input channel. It calls wg.Done when it is finished. -func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan string, output chan<- string, metaChan chan<- routineMetadata, wg *sync.WaitGroup) error { +func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan string, outputChan chan<- string, metaChan chan<- routineMetadata, statusChan chan<- zdns.Status, wg *sync.WaitGroup) error { defer wg.Done() resolver, err := zdns.InitResolver(rc) if err != nil { @@ -590,7 +608,7 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan strin metadata.Status = make(map[zdns.Status]int) for line := range inputChan { - handleWorkerInput(gc, rc, line, resolver, &metadata, output) + handleWorkerInput(gc, rc, line, resolver, &metadata, outputChan, statusChan) } // close the resolver, freeing up resources resolver.Close() @@ -598,7 +616,7 @@ func doLookupWorker(gc *CLIConf, rc *zdns.ResolverConfig, inputChan <-chan strin return nil } -func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolver *zdns.Resolver, metadata *routineMetadata, output chan<- string) { +func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolver *zdns.Resolver, metadata *routineMetadata, outputChan chan<- string, statusChan chan<- zdns.Status) { // we'll process each module sequentially, parallelism is per-domain res := zdns.Result{Results: make(map[string]zdns.SingleModuleResult, len(gc.ActiveModules))} // get the fields that won't change for each lookup module @@ -669,6 +687,9 @@ func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolv lookupRes.Error = err.Error() } res.Results[moduleName] = lookupRes + if !gc.QuietStatusUpdates { + statusChan <- status + } } metadata.Status[status]++ metadata.Lookups++ @@ -689,7 +710,7 @@ func handleWorkerInput(gc *CLIConf, rc *zdns.ResolverConfig, line string, resolv if err != nil { log.Fatalf("unable to marshal JSON result: %v", err) } - output <- string(jsonRes) + outputChan <- string(jsonRes) } metadata.Names++ } diff --git a/testing/integration_tests.py b/testing/integration_tests.py index 222d79ee..07ef85a2 100755 --- a/testing/integration_tests.py +++ b/testing/integration_tests.py @@ -40,7 +40,7 @@ def dictSort(d): class Tests(unittest.TestCase): maxDiff = None ZDNS_EXECUTABLE = "./zdns" - ADDITIONAL_FLAGS = " --threads=10" # flags used with every test + ADDITIONAL_FLAGS = " --threads=10 --quiet" # flags used with every test def run_zdns_check_failure(self, flags, name, expected_err, executable=ZDNS_EXECUTABLE): flags = flags + self.ADDITIONAL_FLAGS