Skip to content

Commit

Permalink
enhance testing
Browse files Browse the repository at this point in the history
  • Loading branch information
cviecco committed Oct 21, 2024
1 parent e480b72 commit e8dc19a
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 17 deletions.
39 changes: 26 additions & 13 deletions cmd/keymasterd/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ type RuntimeState struct {
totpLocalRateLimit map[string]totpRateLimitInfo
totpLocalTateLimitMutex sync.Mutex
sshCertAuthenticator *sshcertauth.Authenticator
serviceMux *http.ServeMux
serviceAccessLogger *serverlogger.Logger
adminAccessLogger *serverlogger.Logger
adminDashboard *adminDashboardType
logger log.DebugLogger
}

Expand Down Expand Up @@ -1827,32 +1831,34 @@ func main() {
}
logger.Debugf(3, "After load verify")
startServerAfterLoad(runtimeState, realLogger)
logger.Debugf(3, "After server initbase")
startListenersAndWaitForUnsealing(runtimeState)
}

func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.Logger) {
var err error

publicLogs := runtimeState.Config.Base.PublicLogs
adminDashboard := newAdminDashboard(realLogger, publicLogs)
runtimeState.adminDashboard = newAdminDashboard(realLogger, publicLogs)

logBufOptions := logbuf.GetStandardOptions()
accessLogDirectory := filepath.Join(logBufOptions.Directory, "access")
logger.Debugf(1, "accesslogdir=%s\n", accessLogDirectory)
serviceAccessLogger := serverlogger.NewWithOptions("access",
runtimeState.serviceAccessLogger = serverlogger.NewWithOptions("access",
logbuf.Options{MaxFileSize: 10 << 20,
Quota: 100 << 20, MaxBufferLines: 100,
Directory: accessLogDirectory},
stdlog.LstdFlags)

adminAccesLogDirectory := filepath.Join(logBufOptions.Directory, "access-admin")
adminAccessLogger := serverlogger.NewWithOptions("access-admin",
runtimeState.adminAccessLogger = serverlogger.NewWithOptions("access-admin",
logbuf.Options{MaxFileSize: 10 << 20,
Quota: 100 << 20, MaxBufferLines: 100,
Directory: adminAccesLogDirectory},
stdlog.LstdFlags)

// Expose the registered metrics via HTTP.
http.Handle("/", adminDashboard)
http.Handle("/", runtimeState.adminDashboard)
http.Handle("/prometheus_metrics", promhttp.Handler()) //lint:ignore SA1019 TODO: newer prometheus handler
http.HandleFunc(secretInjectorPath, runtimeState.secretInjectorHandler)
http.HandleFunc(readyzPath, runtimeState.readyzHandler)
Expand Down Expand Up @@ -1951,12 +1957,19 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L
runtimeState.VerifyAuthTokenHandler)
}
// TODO: only enable these handlers if sshcertauth is enabled
serviceMux.HandleFunc(sshcertauth.DefaultCreateChallengePath,
runtimeState.sshCertAuthCreateChallengeHandler)
serviceMux.HandleFunc(sshcertauth.DefaultLoginWithChallengePath,
runtimeState.sshCertAuthLoginWithChallengeHandler)

if runtimeState.isSelfSSHCertAuthenticatorEnabled() {
serviceMux.HandleFunc(sshcertauth.DefaultCreateChallengePath,
runtimeState.sshCertAuthCreateChallengeHandler)
serviceMux.HandleFunc(sshcertauth.DefaultLoginWithChallengePath,
runtimeState.sshCertAuthLoginWithChallengeHandler)
}
serviceMux.HandleFunc("/", runtimeState.defaultPathHandler)
runtimeState.serviceMux = serviceMux
}

func startListenersAndWaitForUnsealing(runtimeState *RuntimeState) {
var err error
publicLogs := runtimeState.Config.Base.PublicLogs

cfg := &tls.Config{
ClientCAs: runtimeState.ClientCAPool,
Expand All @@ -1976,8 +1989,8 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L
}
logFilterHandler := NewLogFilterHandler(http.DefaultServeMux, publicLogs,
runtimeState)
serviceHTTPLogger := httpLogger{AccessLogger: serviceAccessLogger}
adminHTTPLogger := httpLogger{AccessLogger: adminAccessLogger}
serviceHTTPLogger := httpLogger{AccessLogger: runtimeState.serviceAccessLogger}
adminHTTPLogger := httpLogger{AccessLogger: runtimeState.adminAccessLogger}
adminSrv := &http.Server{
Addr: runtimeState.Config.Base.AdminAddress,
TLSConfig: cfg,
Expand Down Expand Up @@ -2050,7 +2063,7 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L
}
serviceSrv := &http.Server{
Addr: runtimeState.Config.Base.HttpAddress,
Handler: instrumentedwriter.NewLoggingHandler(serviceMux, serviceHTTPLogger),
Handler: instrumentedwriter.NewLoggingHandler(runtimeState.serviceMux, serviceHTTPLogger),
TLSConfig: serviceTLSConfig,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
Expand All @@ -2061,7 +2074,7 @@ func startServerAfterLoad(runtimeState *RuntimeState, realLogger *serverlogger.L
go func() {
time.Sleep(time.Millisecond * 10)
healthserver.SetReady()
adminDashboard.setReady()
runtimeState.adminDashboard.setReady()
}()
err = serviceSrv.ListenAndServeTLS("", "")
if err != nil {
Expand Down
13 changes: 9 additions & 4 deletions cmd/keymasterd/auth_oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ func init() {
//logger = stdlog.New(os.Stderr, "", stdlog.LstdFlags)
slogger := stdlog.New(os.Stderr, "", stdlog.LstdFlags)
logger = debuglogger.New(slogger)
http.HandleFunc("/userinfo", userinfoHandler)
http.HandleFunc("/token", tokenHandler)
http.HandleFunc("/", handler)
testMux := http.NewServeMux()
testMux.HandleFunc("/userinfo", userinfoHandler)
testMux.HandleFunc("/token", tokenHandler)
testMux.HandleFunc("/", handler)
testServer := http.Server{
Handler: testMux,
Addr: "127.0.0.1:12345",
}
logger.Printf("about to start server")
go http.ListenAndServe("127.0.0.1:12345", nil)
go testServer.ListenAndServe()
time.Sleep(20 * time.Millisecond)
_, err := http.Get("http://localhost:12345")
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions cmd/keymasterd/auth_sshcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ func (state *RuntimeState) initialzeSelfSSHCertAuthenticator() error {
return state.sshCertAuthenticator.UnsafeUpdateCaKeys(sshTrustedKeys)
}

func (state *RuntimeState) isSelfSSHCertAuthenticatorEnabled() bool {
for _, certPref := range state.Config.Base.AllowedAuthBackendsForCerts {
if certPref == proto.AuthTypeSSHCert {
return true
}
}
return false
}

// CreateChallengeHandler is an example of how to write a handler for
// the path to create the challenge
func (s *RuntimeState) sshCertAuthCreateChallengeHandler(w http.ResponseWriter, r *http.Request) {
Expand Down
25 changes: 25 additions & 0 deletions cmd/keymasterd/auth_sshcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"os"
"testing"

"github.com/Cloud-Foundations/Dominator/lib/log/serverlogger"
"github.com/Cloud-Foundations/keymaster/lib/webapi/v0/proto"
"github.com/cviecco/webauth-sshcert/lib/server/sshcertauth"
)

Expand All @@ -22,6 +24,17 @@ func TestInitializeSSHAuthenticator(t *testing.T) {
}
}

func TestIsSelfSSHCertAuthenticatorEnabled(t *testing.T) {
state := RuntimeState{}
if state.isSelfSSHCertAuthenticatorEnabled() {
t.Fatal("it should not be enabled on empty state")
}
state.Config.Base.AllowedAuthBackendsForCerts = append(state.Config.Base.AllowedAuthBackendsForCerts, proto.AuthTypeSSHCert)
if !state.isSelfSSHCertAuthenticatorEnabled() {
t.Fatal("it should be enabled on empty state")
}
}

func TestSshCertAuthCreateChallengeHandlert(t *testing.T) {
state, passwdFile, err := setupValidRuntimeStateSigner(t)
if err != nil {
Expand Down Expand Up @@ -53,5 +66,17 @@ func TestSshCertAuthCreateChallengeHandlert(t *testing.T) {
if err != nil {
t.Fatal(err)
}
}

func TestSshCertAuthLoginWithChallengeHandler(t *testing.T) {
state, passwdFile, err := setupValidRuntimeStateSigner(t)
if err != nil {
t.Fatal(err)
}
defer os.Remove(passwdFile.Name()) // clean up
state.Config.Base.AllowedAuthBackendsForCerts = append(state.Config.Base.AllowedAuthBackendsForCerts, proto.AuthTypeSSHCert)
realLogger := serverlogger.New("") //TODO, we need to find a simulator for this
startServerAfterLoad(state, realLogger)

//TODO: write the actual test, at this point we only have the endpoints initalized
}

0 comments on commit e8dc19a

Please sign in to comment.