Skip to content

Commit

Permalink
add test for enrollment hash lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Aug 28, 2023
1 parent 609c08a commit 367c587
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
2 changes: 2 additions & 0 deletions http/mdm/mdm_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ func CertWithEnrollmentIDMiddleware(next http.Handler, hasher HashFn, store stor
ctxlog.Logger(r.Context(), logger).Info(
"err", "missing certificate",
)
// we cannot send a 401 to the client as it has MDM protocol semantics
// i.e. the device may unenroll
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest)
return
} else {
Expand Down
76 changes: 76 additions & 0 deletions http/mdm/mdm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package mdm

import (
"bytes"
"context"
"crypto/x509"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/micromdm/nanomdm/log"
)

const (
testHash = "ZZZYYYXXX"
testID = "AAABBBCCC"
)

func testHashCert(_ *x509.Certificate) string {
return testHash
}

type testCertAuthRetriever struct{}

func (c *testCertAuthRetriever) EnrollmentFromHash(ctx context.Context, hash string) (string, error) {
if hash != testHash {
return "", errors.New("invalid test hash")
}
return testID, nil
}

func TestCertWithEnrollmentIDMiddleware(t *testing.T) {
response := []byte("mock response")

// mock handler
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write(response)
})

handler = CertWithEnrollmentIDMiddleware(handler, testHashCert, &testCertAuthRetriever{}, true, log.NopLogger)

req, err := http.NewRequest("GET", "/test", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

// we requested enforcement, so make sure we get a BadResponse
if have, want := rr.Code, http.StatusBadRequest; have != want {
t.Errorf("have: %d, want: %d", have, want)
}

req, err = http.NewRequest("GET", "/test", nil)
if err != nil {
t.Fatal(err)
}

req = req.WithContext(context.WithValue(req.Context(), contextKeyCert{}, &x509.Certificate{}))

rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)

// now that we have a "cert" included, we should get an OK
if have, want := rr.Code, http.StatusOK; have != want {
t.Errorf("have: %d, want: %d", have, want)
}

// verify the actual body, too
if !bytes.Equal(rr.Body.Bytes(), response) {
t.Error("body not equal")
}
}

0 comments on commit 367c587

Please sign in to comment.