From d644497c8d16a3aaed6d2d7b14d69fd6b9005230 Mon Sep 17 00:00:00 2001 From: Dusan Klinec Date: Tue, 28 May 2024 10:00:33 +0200 Subject: [PATCH] [TECH] sync with upstream (#17) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add flavour, version command, fix version source (#229) - make makefile single source of truth for version - trigger the flow in the tests * minor tests enhancements (#232) --------- Co-authored-by: DuĊĦan Klinec Co-authored-by: cviecco --- Makefile | 2 +- cmd/keymaster/main.go | 38 +++++++++++++++++++++++++------------- cmd/keymaster/main_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index 9155ec9..31e8de4 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ endif BINARY=keymaster # These are the values we want to pass for Version and BuildTime -VERSION?=1.15.12 +VERSION?=1.15.13 DEFAULT_HOST?= VERSION_FLAVOUR?= EXTRA_LDFLAGS?= diff --git a/cmd/keymaster/main.go b/cmd/keymaster/main.go index 1874130..a669e9d 100644 --- a/cmd/keymaster/main.go +++ b/cmd/keymaster/main.go @@ -7,6 +7,7 @@ import ( "errors" "flag" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -475,7 +476,7 @@ func getHttpClient(rootCAs *x509.CertPool, logger log.DebugLogger) (*http.Client } if *roundRobinDialer { if rrDialer, err := rrdialer.New(rawDialer, "", logger); err != nil { - logger.Fatalln(err) + return nil, err } else { defer rrDialer.WaitForBackgroundResults(time.Second) dialer = rrDialer @@ -488,38 +489,37 @@ func getHttpClient(rootCAs *x509.CertPool, logger log.DebugLogger) (*http.Client } func Usage() { + computeUserAgent() fmt.Fprintf(os.Stderr, "Usage: %s [flags...] [aws-role-cert]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Version: %s\n", userAgentString) flag.PrintDefaults() } -func main() { +// We assume here flags are parsed +func mainWithError(stdout io.Writer, logger log.DebugLogger) error { computeUserAgent() - flag.Usage = Usage - flag.Parse() - logger := cmdlogger.New() if *printVersion { - fmt.Println(Version) - return + fmt.Fprintln(stdout, Version) + return nil } rootCAs, err := maybeGetRootCas(*rootCAFilename, logger) if err != nil { - logger.Fatal(err) + return err } client, err := getHttpClient(rootCAs, logger) if err != nil { - logger.Fatal(err) + return err } if *checkDevices { err = u2f.CheckU2FDevices(logger) if err != nil { - logger.Fatal(err) + return err } - return + return nil } userName, homeDir, err := util.GetUserNameAndHomeDir() if err != nil { - logger.Fatal(err) + return err } config := loadConfigFile(client, logger) logger.Debugf(3, "loaded Config=%+v", config) @@ -543,7 +543,19 @@ func main() { err = setupCerts(userName, homeDir, config, client, logger) } if err != nil { - logger.Fatal(err) + return err } logger.Printf("Success") + return nil +} + +func main() { + flag.Usage = Usage + flag.Parse() + logger := cmdlogger.New() + err := mainWithError(os.Stdout, logger) + if err != nil { + logger.Fatal(err) + } + } diff --git a/cmd/keymaster/main_test.go b/cmd/keymaster/main_test.go index 345dbd5..326bbbc 100644 --- a/cmd/keymaster/main_test.go +++ b/cmd/keymaster/main_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "crypto/ed25519" "crypto/rand" "crypto/tls" @@ -20,6 +21,7 @@ import ( "github.com/Cloud-Foundations/golib/pkg/log/testlogger" "github.com/Cloud-Foundations/keymaster/lib/client/config" + "github.com/Cloud-Foundations/keymaster/lib/client/twofa/u2f" "github.com/Cloud-Foundations/keymaster/lib/client/util" "github.com/Cloud-Foundations/keymaster/lib/webapi/v0/proto" ) @@ -285,3 +287,32 @@ func TestMainPrintVersion(t *testing.T) { }() <-done } + +func TestMainSimple(t *testing.T) { + logger := testlogger.New(t) + var b bytes.Buffer + + // version + *printVersion = true + err := mainWithError(&b, logger) + if err != nil { + t.Fatal(err) + } + t.Logf("versionout='%s'", b.String()) + // TODO: compara out to version string + *printVersion = false + b.Reset() + + // checkDevices + *checkDevices = true + // As of May 2024, no devices returns an error on checkForDevices + // Because this will run inside or outside testing infra, we can + // only check if the error is consistent if any + checkDevRvalue := u2f.CheckU2FDevices(logger) + err = mainWithError(&b, logger) + if err != nil && (err.Error() != checkDevRvalue.Error()) { + t.Fatalf("manual an executed error mismatch mainerr=%s; chdevDerr=%s", err, checkDevRvalue) + } + *checkDevices = false + b.Reset() +}