Skip to content

Commit

Permalink
add fallback verifier and make verifiers context aware (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson authored Nov 29, 2023
1 parent 012ad61 commit e9fcd9b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 4 deletions.
41 changes: 41 additions & 0 deletions certverify/fallback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package certverify

import (
"context"
"crypto/x509"
"errors"
"fmt"
"strings"
)

// CertVerifier is a simple interface for verifying a certificate.
type CertVerifier interface {
Verify(context.Context, *x509.Certificate) error
}

// FallbackVerifier verfies certificate validity using multiple verifiers.
type FallbackVerifier struct {
verifiers []CertVerifier
}

// NewFallbackVerifier creates a new verifier using other verifiers.
func NewFallbackVerifier(verifiers ...CertVerifier) *FallbackVerifier {
return &FallbackVerifier{verifiers: verifiers}
}

// Verify performs certificate verification.
// Any verifier returning nil ("passes") will pass (return nil) and not
// check any other verifier.
// If all verifiers return non-nil ("fail") then an error for all
// verifiers will be returned.
func (v *FallbackVerifier) Verify(ctx context.Context, cert *x509.Certificate) error {
var errs []string
for i, verifier := range v.verifiers {
err := verifier.Verify(ctx, cert)
if err == nil {
return nil
}
errs = append(errs, fmt.Sprintf("fallback error (%d): %v", i, err))
}
return errors.New(strings.Join(errs, "; "))
}
50 changes: 50 additions & 0 deletions certverify/fallback_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package certverify

import (
"context"
"crypto/x509"
"errors"
"testing"
)

type errVerifier struct{ err error }

func (v *errVerifier) Verify(_ context.Context, _ *x509.Certificate) error {
return v.err
}

var nilErroringVerifier = &errVerifier{}
var errErroringVerifier = &errVerifier{err: errors.New("verifier error")}

func TestFallbackVerifier(t *testing.T) {
v := NewFallbackVerifier(nilErroringVerifier)
err := v.Verify(nil, nil)
if err != nil {
t.Errorf("should not have errored: %v", err)
}

v = NewFallbackVerifier(nilErroringVerifier, nilErroringVerifier)
if err = v.Verify(nil, nil); err != nil {
t.Errorf("should not have errored: %v", err)
}

v = NewFallbackVerifier(errErroringVerifier)
if err = v.Verify(nil, nil); err == nil {
t.Error("should have errored")
}

v = NewFallbackVerifier(errErroringVerifier, nilErroringVerifier)
if err = v.Verify(nil, nil); err != nil {
t.Errorf("should not have errored: %v", err)
}

v = NewFallbackVerifier(nilErroringVerifier, errErroringVerifier)
if err = v.Verify(nil, nil); err != nil {
t.Errorf("should not have errored: %v", err)
}

v = NewFallbackVerifier(errErroringVerifier, errErroringVerifier)
if err = v.Verify(nil, nil); err == nil {
t.Error("should have errored")
}
}
3 changes: 2 additions & 1 deletion certverify/pool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package certverify

import (
"context"
"crypto/x509"
"errors"
)
Expand Down Expand Up @@ -31,7 +32,7 @@ func NewPoolVerifier(rootsPEM []byte, intsPEM []byte, keyUsages ...x509.ExtKeyUs
}

// Verify performs certificate verification
func (v *PoolVerifier) Verify(cert *x509.Certificate) error {
func (v *PoolVerifier) Verify(_ context.Context, cert *x509.Certificate) error {
if cert == nil {
return errors.New("missing MDM certificate")
}
Expand Down
3 changes: 2 additions & 1 deletion certverify/signature.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package certverify

import (
"context"
"crypto/x509"
"errors"

Expand All @@ -25,7 +26,7 @@ func NewSignatureVerifier(rootPEM []byte) (*SignatureVerifier, error) {
}

// Verify checks only the signature of the certificate against the CA
func (v *SignatureVerifier) Verify(cert *x509.Certificate) error {
func (v *SignatureVerifier) Verify(_ context.Context, cert *x509.Certificate) error {
if cert == nil {
return errors.New("missing MDM certificate")
}
Expand Down
4 changes: 2 additions & 2 deletions http/mdm/mdm_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func GetCert(ctx context.Context) *x509.Certificate {

// CertVerifier is a simple interface for verifying a certificate.
type CertVerifier interface {
Verify(*x509.Certificate) error
Verify(context.Context, *x509.Certificate) error
}

// CertVerifyMiddleware checks the MDM certificate against verifier and
Expand All @@ -120,7 +120,7 @@ type CertVerifier interface {
// MDM unenrollments in the case of bugs or something going wrong.
func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := verifier.Verify(GetCert(r.Context())); err != nil {
if err := verifier.Verify(r.Context(), GetCert(r.Context())); err != nil {
ctxlog.Logger(r.Context(), logger).Info(
"msg", "error verifying MDM certificate",
"err", err,
Expand Down

0 comments on commit e9fcd9b

Please sign in to comment.