diff --git a/cmd/keymaster/main.go b/cmd/keymaster/main.go index 2640762..e7c0ea1 100644 --- a/cmd/keymaster/main.go +++ b/cmd/keymaster/main.go @@ -7,6 +7,7 @@ import ( "errors" "flag" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -475,7 +476,7 @@ func getHttpClient(rootCAs *x509.CertPool, logger log.DebugLogger) (*http.Client } if *roundRobinDialer { if rrDialer, err := rrdialer.New(rawDialer, "", logger); err != nil { - logger.Fatalln(err) + return nil, err } else { defer rrDialer.WaitForBackgroundResults(time.Second) dialer = rrDialer @@ -493,33 +494,31 @@ func Usage() { flag.PrintDefaults() } -func main() { - flag.Usage = Usage - flag.Parse() - logger := cmdlogger.New() +// We assume here flags are parsed +func mainWithError(stdout io.Writer, logger log.DebugLogger) error { if *printVersion { - fmt.Println(Version) - return + fmt.Fprintln(stdout, Version) + return nil } rootCAs, err := maybeGetRootCas(*rootCAFilename, logger) if err != nil { - logger.Fatal(err) + return err } client, err := getHttpClient(rootCAs, logger) if err != nil { - logger.Fatal(err) + return err } if *checkDevices { err = u2f.CheckU2FDevices(logger) if err != nil { - logger.Fatal(err) + return err } - return + return nil } computeUserAgent() userName, homeDir, err := util.GetUserNameAndHomeDir() if err != nil { - logger.Fatal(err) + return err } config := loadConfigFile(client, logger) logger.Debugf(3, "loaded Config=%+v", config) @@ -543,7 +542,19 @@ func main() { err = setupCerts(userName, homeDir, config, client, logger) } if err != nil { - logger.Fatal(err) + return err } logger.Printf("Success") + return nil +} + +func main() { + flag.Usage = Usage + flag.Parse() + logger := cmdlogger.New() + err := mainWithError(os.Stdout, logger) + if err != nil { + logger.Fatal(err) + } + } diff --git a/cmd/keymaster/main_test.go b/cmd/keymaster/main_test.go index 30d4437..81b1b90 100644 --- a/cmd/keymaster/main_test.go +++ b/cmd/keymaster/main_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "crypto/ed25519" "crypto/rand" "crypto/tls" @@ -20,6 +21,7 @@ import ( "github.com/Cloud-Foundations/golib/pkg/log/testlogger" "github.com/Cloud-Foundations/keymaster/lib/client/config" + "github.com/Cloud-Foundations/keymaster/lib/client/twofa/u2f" "github.com/Cloud-Foundations/keymaster/lib/client/util" "github.com/Cloud-Foundations/keymaster/lib/webapi/v0/proto" ) @@ -275,3 +277,33 @@ func TestInsertSSHCertIntoAgentORWriteToFilesystem(t *testing.T) { // TODO: on linux/macos create agent + unix socket and pass that } + +func TestMainSimple(t *testing.T) { + logger := testlogger.New(t) + var b bytes.Buffer + + // version + *printVersion = true + err := mainWithError(&b, logger) + if err != nil { + t.Fatal(err) + } + t.Logf("versionout='%s'", b.String()) + // TODO: compara out to version string + *printVersion = false + b.Reset() + + // checkDevices + *checkDevices = true + // As of May 2024, no devices returns an error on checkForDevices + // Because this will run inside or outside testing infra, we can + // only check if the error is consistent if any + checkDevRvalue := u2f.CheckU2FDevices(logger) + err = mainWithError(&b, logger) + if err != nil && (err.Error() != checkDevRvalue.Error()) { + t.Fatalf("manual an executed error mismatch mainerr=%s; chdevDerr=%s", err, checkDevRvalue) + } + *checkDevices = false + b.Reset() + +}