diff --git a/cmd/keymasterd/app.go b/cmd/keymasterd/app.go index 804767a..b86284f 100644 --- a/cmd/keymasterd/app.go +++ b/cmd/keymasterd/app.go @@ -223,6 +223,10 @@ type RuntimeState struct { totpLocalRateLimit map[string]totpRateLimitInfo totpLocalTateLimitMutex sync.Mutex sshCertAuthenticator *sshcertauth.Authenticator + serviceMux *http.ServeMux + serviceAccessLogger *serverlogger.Logger + adminAccessLogger *serverlogger.Logger + adminDashboard *adminDashboardType logger log.DebugLogger } @@ -1827,32 +1831,34 @@ func main() { } logger.Debugf(3, "After load verify") startServerAfterLoad(runtimeState, realLogger) + logger.Debugf(3, "After server initbase") + startListenersAndWaitForUnsealing(runtimeState) } func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.Logger) { var err error publicLogs := runtimeState.Config.Base.PublicLogs - adminDashboard := newAdminDashboard(realLogger, publicLogs) + runtimeState.adminDashboard = newAdminDashboard(realLogger, publicLogs) logBufOptions := logbuf.GetStandardOptions() accessLogDirectory := filepath.Join(logBufOptions.Directory, "access") logger.Debugf(1, "accesslogdir=%s\n", accessLogDirectory) - serviceAccessLogger := serverlogger.NewWithOptions("access", + runtimeState.serviceAccessLogger = serverlogger.NewWithOptions("access", logbuf.Options{MaxFileSize: 10 << 20, Quota: 100 << 20, MaxBufferLines: 100, Directory: accessLogDirectory}, stdlog.LstdFlags) adminAccesLogDirectory := filepath.Join(logBufOptions.Directory, "access-admin") - adminAccessLogger := serverlogger.NewWithOptions("access-admin", + runtimeState.adminAccessLogger = serverlogger.NewWithOptions("access-admin", logbuf.Options{MaxFileSize: 10 << 20, Quota: 100 << 20, MaxBufferLines: 100, Directory: adminAccesLogDirectory}, stdlog.LstdFlags) // Expose the registered metrics via HTTP. - http.Handle("/", adminDashboard) + http.Handle("/", runtimeState.adminDashboard) http.Handle("/prometheus_metrics", promhttp.Handler()) //lint:ignore SA1019 TODO: newer prometheus handler http.HandleFunc(secretInjectorPath, runtimeState.secretInjectorHandler) http.HandleFunc(readyzPath, runtimeState.readyzHandler) @@ -1951,12 +1957,19 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L runtimeState.VerifyAuthTokenHandler) } // TODO: only enable these handlers if sshcertauth is enabled - serviceMux.HandleFunc(sshcertauth.DefaultCreateChallengePath, - runtimeState.sshCertAuthCreateChallengeHandler) - serviceMux.HandleFunc(sshcertauth.DefaultLoginWithChallengePath, - runtimeState.sshCertAuthLoginWithChallengeHandler) - + if runtimeState.isSelfSSHCertAuthenticatorEnabled() { + serviceMux.HandleFunc(sshcertauth.DefaultCreateChallengePath, + runtimeState.sshCertAuthCreateChallengeHandler) + serviceMux.HandleFunc(sshcertauth.DefaultLoginWithChallengePath, + runtimeState.sshCertAuthLoginWithChallengeHandler) + } serviceMux.HandleFunc("/", runtimeState.defaultPathHandler) + runtimeState.serviceMux = serviceMux +} + +func startListenersAndWaitForUnsealing(runtimeState *RuntimeState) { + var err error + publicLogs := runtimeState.Config.Base.PublicLogs cfg := &tls.Config{ ClientCAs: runtimeState.ClientCAPool, @@ -1976,8 +1989,8 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L } logFilterHandler := NewLogFilterHandler(http.DefaultServeMux, publicLogs, runtimeState) - serviceHTTPLogger := httpLogger{AccessLogger: serviceAccessLogger} - adminHTTPLogger := httpLogger{AccessLogger: adminAccessLogger} + serviceHTTPLogger := httpLogger{AccessLogger: runtimeState.serviceAccessLogger} + adminHTTPLogger := httpLogger{AccessLogger: runtimeState.adminAccessLogger} adminSrv := &http.Server{ Addr: runtimeState.Config.Base.AdminAddress, TLSConfig: cfg, @@ -2050,7 +2063,7 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L } serviceSrv := &http.Server{ Addr: runtimeState.Config.Base.HttpAddress, - Handler: instrumentedwriter.NewLoggingHandler(serviceMux, serviceHTTPLogger), + Handler: instrumentedwriter.NewLoggingHandler(runtimeState.serviceMux, serviceHTTPLogger), TLSConfig: serviceTLSConfig, ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, @@ -2061,7 +2074,7 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L go func() { time.Sleep(time.Millisecond * 10) healthserver.SetReady() - adminDashboard.setReady() + runtimeState.adminDashboard.setReady() }() err = serviceSrv.ListenAndServeTLS("", "") if err != nil { diff --git a/cmd/keymasterd/auth_oauth2_test.go b/cmd/keymasterd/auth_oauth2_test.go index 649d177..6bad957 100644 --- a/cmd/keymasterd/auth_oauth2_test.go +++ b/cmd/keymasterd/auth_oauth2_test.go @@ -67,11 +67,16 @@ func init() { //logger = stdlog.New(os.Stderr, "", stdlog.LstdFlags) slogger := stdlog.New(os.Stderr, "", stdlog.LstdFlags) logger = debuglogger.New(slogger) - http.HandleFunc("/userinfo", userinfoHandler) - http.HandleFunc("/token", tokenHandler) - http.HandleFunc("/", handler) + testMux := http.NewServeMux() + testMux.HandleFunc("/userinfo", userinfoHandler) + testMux.HandleFunc("/token", tokenHandler) + testMux.HandleFunc("/", handler) + testServer := http.Server{ + Handler: testMux, + Addr: "127.0.0.1:12345", + } logger.Printf("about to start server") - go http.ListenAndServe("127.0.0.1:12345", nil) + go testServer.ListenAndServe() time.Sleep(20 * time.Millisecond) _, err := http.Get("http://localhost:12345") if err != nil { diff --git a/cmd/keymasterd/auth_sshcert.go b/cmd/keymasterd/auth_sshcert.go index 532a857..03b561f 100644 --- a/cmd/keymasterd/auth_sshcert.go +++ b/cmd/keymasterd/auth_sshcert.go @@ -25,6 +25,15 @@ func (state *RuntimeState) initialzeSelfSSHCertAuthenticator() error { return state.sshCertAuthenticator.UnsafeUpdateCaKeys(sshTrustedKeys) } +func (state *RuntimeState) isSelfSSHCertAuthenticatorEnabled() bool { + for _, certPref := range state.Config.Base.AllowedAuthBackendsForCerts { + if certPref == proto.AuthTypeSSHCert { + return true + } + } + return false +} + // CreateChallengeHandler is an example of how to write a handler for // the path to create the challenge func (s *RuntimeState) sshCertAuthCreateChallengeHandler(w http.ResponseWriter, r *http.Request) { diff --git a/cmd/keymasterd/auth_sshcert_test.go b/cmd/keymasterd/auth_sshcert_test.go index ef311f9..aa332c2 100644 --- a/cmd/keymasterd/auth_sshcert_test.go +++ b/cmd/keymasterd/auth_sshcert_test.go @@ -5,6 +5,8 @@ import ( "os" "testing" + "github.com/Cloud-Foundations/Dominator/lib/log/serverlogger" + "github.com/Cloud-Foundations/keymaster/lib/webapi/v0/proto" "github.com/cviecco/webauth-sshcert/lib/server/sshcertauth" ) @@ -22,6 +24,17 @@ func TestInitializeSSHAuthenticator(t *testing.T) { } } +func TestIsSelfSSHCertAuthenticatorEnabled(t *testing.T) { + state := RuntimeState{} + if state.isSelfSSHCertAuthenticatorEnabled() { + t.Fatal("it should not be enabled on empty state") + } + state.Config.Base.AllowedAuthBackendsForCerts = append(state.Config.Base.AllowedAuthBackendsForCerts, proto.AuthTypeSSHCert) + if !state.isSelfSSHCertAuthenticatorEnabled() { + t.Fatal("it should be enabled on empty state") + } +} + func TestSshCertAuthCreateChallengeHandlert(t *testing.T) { state, passwdFile, err := setupValidRuntimeStateSigner(t) if err != nil { @@ -53,5 +66,17 @@ func TestSshCertAuthCreateChallengeHandlert(t *testing.T) { if err != nil { t.Fatal(err) } +} + +func TestSshCertAuthLoginWithChallengeHandler(t *testing.T) { + state, passwdFile, err := setupValidRuntimeStateSigner(t) + if err != nil { + t.Fatal(err) + } + defer os.Remove(passwdFile.Name()) // clean up + state.Config.Base.AllowedAuthBackendsForCerts = append(state.Config.Base.AllowedAuthBackendsForCerts, proto.AuthTypeSSHCert) + realLogger := serverlogger.New("") //TODO, we need to find a simulator for this + startServerAfterLoad(state, realLogger) + //TODO: write the actual test, at this point we only have the endpoints initalized }