diff --git a/authority/authorize_test.go b/authority/authorize_test.go index b83cb2519..f7287e7a5 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -89,6 +89,39 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for k, v := range extraHeaders { + so.WithHeader(jose.HeaderKey(k), v) + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + + iat := time.Now() + claims := jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + } + + return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() +} + func TestAuthority_authorizeToken(t *testing.T) { a := testAuthority(t) @@ -510,7 +543,7 @@ func TestAuthority_authorizeSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, 10, len(got)) // number of provisioner.SignOptions returned + assert.Equals(t, 11, len(got)) // number of provisioner.SignOptions returned } } }) diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index 8f6211d3a..43763e2f1 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -493,8 +493,8 @@ func (p *GCP) genHostOptions(_ context.Context, claims *gcpPayload) (SignSSHOpti return SignSSHOptions{CertType: SSHHostCert}, keyID, principals, sshutil.HostCert, sshutil.DefaultIIDTemplate } -func FormatServiceAccountUsername(serviceAccountId string) string { - return fmt.Sprintf("sa_%v", serviceAccountId) +func FormatServiceAccountUsername(serviceAccountID string) string { + return fmt.Sprintf("sa_%v", serviceAccountID) } func (p *GCP) genUserOptions(_ context.Context, claims *gcpPayload) (SignSSHOptions, string, []string, sshutil.CertType, string) { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 13e8bd485..ed481877a 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -19,8 +19,9 @@ import ( // jwtPayload extends jwt.Claims with step attributes. type jwtPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` - Step *stepPayload `json:"step,omitempty"` + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` + Confirmation *cnfPayload `json:"cnf,omitempty"` } type stepPayload struct { @@ -28,6 +29,10 @@ type stepPayload struct { RA *RAInfo `json:"ra,omitempty"` } +type cnfPayload struct { + Fingerprint string `json:"x5rt#S256,omitempty"` +} + // JWK is the default provisioner, an entity that can sign tokens necessary for // signature requests. type JWK struct { @@ -183,6 +188,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } } + // Check the fingerprint of the certificate request if given. + var fingerprint string + if claims.Confirmation != nil { + fingerprint = claims.Confirmation.Fingerprint + } + return []SignOption{ self, templateOptions, @@ -190,6 +201,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators + csrFingerprintValidator(fingerprint), commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)), defaultPublicKeyValidator{}, newDefaultSANsValidator(ctx, claims.SANs), diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 794fe1eaa..68fb7f47a 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -13,7 +13,9 @@ import ( "testing" "time" + "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" + "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" @@ -247,6 +249,9 @@ func TestJWK_AuthorizeSign(t *testing.T) { t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1) assert.FatalError(t, err) + t3, err := generateCustomToken("subject", p1.Name, testAudiences.Sign[0], key1, nil, map[string]any{"cnf": map[string]any{"x5rt#S256": "fingerprint"}}) + assert.FatalError(t, err) + // invalid signature failSig := t1[0 : len(t1)-2] @@ -254,12 +259,13 @@ func TestJWK_AuthorizeSign(t *testing.T) { token string } tests := []struct { - name string - prov *JWK - args args - code int - err error - sans []string + name string + prov *JWK + args args + code int + err error + sans []string + fingerprint string }{ { name: "fail-signature", @@ -284,6 +290,15 @@ func TestJWK_AuthorizeSign(t *testing.T) { err: nil, sans: []string{"subject"}, }, + { + name: "ok-cnf", + prov: p1, + args: args{t3}, + code: http.StatusOK, + err: nil, + sans: []string{"subject"}, + fingerprint: "fingerprint", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -297,7 +312,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } } else { if assert.NotNil(t, got) { - assert.Equals(t, 10, len(got)) + assert.Equals(t, 11, len(got)) for _, o := range got { switch v := o.(type) { case *JWK: @@ -321,6 +336,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: + case csrFingerprintValidator: + assert.Equals(t, tt.fingerprint, string(v)) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } @@ -393,17 +410,6 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { jwk, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) - iss, aud := p1.Name, testAudiences.SSHSign[0] - - t1, err := generateSimpleSSHUserToken(iss, aud, jwk) - assert.FatalError(t, err) - - t2, err := generateSimpleSSHHostToken(iss, aud, jwk) - assert.FatalError(t, err) - - // invalid signature - failSig := t1[0 : len(t1)-2] - key, err := generateJSONWebKey() assert.FatalError(t, err) @@ -417,6 +423,39 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) + // Calculate fingerprint + sshPub, err := ssh.NewPublicKey(pub) + assert.FatalError(t, err) + fp, err := fingerprint.New(sshPub.Marshal(), crypto.SHA256, fingerprint.Base64RawURLFingerprint) + assert.FatalError(t, err) + + iss, aud := p1.Name, testAudiences.SSHSign[0] + + t1, err := generateSimpleSSHUserToken(iss, aud, jwk) + assert.FatalError(t, err) + + t2, err := generateSimpleSSHHostToken(iss, aud, jwk) + assert.FatalError(t, err) + + t3, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ + "step": map[string]any{ + "ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}}, + }, + "cnf": map[string]any{"kid": fp}, + }) + assert.FatalError(t, err) + + t4, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ + "step": map[string]any{ + "ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}}, + }, + "cnf": map[string]any{"kid": "bad-fingerprint"}, + }) + assert.FatalError(t, err) + + // invalid signature + failSig := t1[0 : len(t1)-2] + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ @@ -451,9 +490,11 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { {"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"host-cnf", p1, args{t3, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, + {"ignore-bad-cnf", p1, args{t4, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false}, {"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, - {"rail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, + {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index a243645e2..b7cf0dbca 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -5,7 +5,10 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "crypto/sha256" + "crypto/subtle" "crypto/x509" + "encoding/base64" "encoding/json" "net" "net/http" @@ -503,3 +506,21 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption cert.ExtraExtensions = append(cert.ExtraExtensions, ext) return nil } + +// csrFingerprintValidator is a CertificateRequestValidator that checks the +// fingerprint of the certificate request with the provided one. +type csrFingerprintValidator string + +func (s csrFingerprintValidator) Valid(cr *x509.CertificateRequest) error { + if s != "" { + expected, err := base64.RawURLEncoding.DecodeString(string(s)) + if err != nil { + return errs.ForbiddenErr(err, "error decoding fingerprint") + } + sum := sha256.Sum256(cr.Raw) + if subtle.ConstantTimeCompare(expected, sum[:]) != 1 { + return errs.Forbidden("certificate request fingerprint does not match %q", s) + } + } + return nil +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 404b16ae7..512a8f0e0 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -44,6 +44,13 @@ type SSHCertOptionsValidator interface { Valid(got SignSSHOptions) error } +// SSHPublicKeyValidator is the interface used to validate the public key of an +// SSH certificate. +type SSHPublicKeyValidator interface { + SignOption + Valid(got ssh.PublicKey) error +} + // SignSSHOptions contains the options that can be passed to the SignSSH method. type SignSSHOptions struct { CertType string `json:"certType"` diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 975192b20..e455f5f50 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -767,6 +767,37 @@ func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jw return jose.Signed(sig).Claims(claims).CompactSerialize() } +func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("kid", jwk.KeyID) + + for k, v := range extraHeaders { + so.WithHeader(jose.HeaderKey(k), v) + } + + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) + if err != nil { + return "", err + } + + id, err := randutil.ASCII(64) + if err != nil { + return "", err + } + iat := time.Now() + claims := jose.Claims{ + ID: id, + Subject: sub, + Issuer: iss, + IssuedAt: jose.NewNumericDate(iat), + NotBefore: jose.NewNumericDate(iat), + Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), + Audience: []string{aud}, + } + return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() +} + func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 9b1f2b086..fd77fe75e 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -21,9 +21,10 @@ import ( // x5cPayload extends jwt.Claims with step attributes. type x5cPayload struct { jose.Claims - SANs []string `json:"sans,omitempty"` - Step *stepPayload `json:"step,omitempty"` - chains [][]*x509.Certificate + SANs []string `json:"sans,omitempty"` + Step *stepPayload `json:"step,omitempty"` + Confirmation *cnfPayload `json:"cnf,omitempty"` + chains [][]*x509.Certificate } // X5C is the default provisioner, an entity that can sign tokens necessary for @@ -233,6 +234,12 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } } + // Check the fingerprint of the certificate request if given. + var fingerprint string + if claims.Confirmation != nil { + fingerprint = claims.Confirmation.Fingerprint + } + return []SignOption{ self, templateOptions, @@ -243,6 +250,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er x5cLeaf.NotBefore, x5cLeaf.NotAfter, }, // validators + csrFingerprintValidator(fingerprint), commonNameValidator(claims.Subject), newDefaultSANsValidator(ctx, claims.SANs), defaultPublicKeyValidator{}, diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 0493d64aa..99d10b68c 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -3,9 +3,11 @@ package provisioner import ( "context" "crypto/x509" + "encoding/base64" "errors" "fmt" "net/http" + "strings" "testing" "time" @@ -14,13 +16,19 @@ import ( "go.step.sm/crypto/randutil" "go.step.sm/linkedca" - "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func assertHasPrefix(t *testing.T, s, p string) bool { + t.Helper() + return assert.True(t, strings.HasPrefix(s, p), "%q is not a prefix of %q", p, s) +} + func TestX5C_Getters(t *testing.T) { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) id := "x5c/" + p.Name if got := p.GetID(); got != id { t.Errorf("X5C.GetID() = %v, want %v:%v", got, p.Name, id) @@ -79,7 +87,7 @@ func TestX5C_Init(t *testing.T) { }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) p.Claims = &Claims{DefaultTLSDur: &Duration{0}} return ProvisionerValidateTest{ p: p, @@ -88,7 +96,7 @@ func TestX5C_Init(t *testing.T) { }, "ok": func(t *testing.T) ProvisionerValidateTest { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return ProvisionerValidateTest{ p: p, } @@ -117,7 +125,7 @@ VR0RBA0wC4IJcm9vdC10ZXN0MAoGCCqGSM49BAMCA0kAMEYCIQC2vgqwla0u8LHH 1MHob14qvS5o76HautbIBW7fcHzz5gIhAIx5A2+wkJYX4026kqaZCk/1sAwTxSGY M46l92gdOozT -----END CERTIFICATE-----`)) - assert.FatalError(t, err) + require.NoError(t, err) return ProvisionerValidateTest{ p: p, extraValid: func(p *X5C) error { @@ -143,11 +151,11 @@ M46l92gdOozT err := tc.p.Init(config) if err != nil { if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) + assert.EqualError(t, tc.err, err.Error()) } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) + assert.Equal(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } @@ -159,9 +167,9 @@ M46l92gdOozT func TestX5C_authorizeToken(t *testing.T) { x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") - assert.FatalError(t, err) + require.NoError(t, err) x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") - assert.FatalError(t, err) + require.NoError(t, err) type test struct { p *X5C @@ -172,7 +180,7 @@ func TestX5C_authorizeToken(t *testing.T) { tests := map[string]func(*testing.T) test{ "fail/bad-token": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -192,15 +200,15 @@ DgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNLJ P9K7MAoGCCqGSM49BAMCA0gAMEUCIQC5c1ldDcesDb31GlO5cEJvOcRrIrNtkk8m a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo= -----END CERTIFICATE-----`)) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + require.NoError(t, err) p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -231,15 +239,15 @@ BgNVHREECTAHggVsZWFmMjAKBggqhkjOPQQDAgNIADBFAiB7gMRy3t81HpcnoRAS ELZmDFaEnoLCsVfbmanFykazQQIhAI0sZjoE9t6gvzQp7XQp6CoxzCc3Jv3FwZ8G EXAHTA9L -----END CERTIFICATE-----`)) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + require.NoError(t, err) p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -272,16 +280,16 @@ E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE-----`)) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + require.NoError(t, err) p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -314,15 +322,15 @@ E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE-----`)) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + require.NoError(t, err) p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -332,11 +340,11 @@ lgsqsR63is+0YQ== }, "fail/invalid-issuer": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -346,11 +354,11 @@ lgsqsR63is+0YQ== }, "fail/invalid-audience": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", p.GetName(), "foobar", "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -360,11 +368,11 @@ lgsqsR63is+0YQ== }, "fail/empty-subject": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -374,11 +382,11 @@ lgsqsR63is+0YQ== }, "ok": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -392,12 +400,12 @@ lgsqsR63is+0YQ== if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tc.code) + assert.Equal(t, tc.code, sc.StatusCode()) } - assert.HasPrefix(t, err.Error(), tc.err.Error()) + assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { - if assert.Nil(t, tc.err) { + if assert.NoError(t, tc.err) { assert.NotNil(t, claims) assert.NotNil(t, claims.chains) } @@ -408,21 +416,22 @@ lgsqsR63is+0YQ== func TestX5C_AuthorizeSign(t *testing.T) { certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") - assert.FatalError(t, err) + require.NoError(t, err) type test struct { - p *X5C - token string - code int - err error - sans []string + p *X5C + token string + code int + err error + sans []string + fingerprint string } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -432,11 +441,11 @@ func TestX5C_AuthorizeSign(t *testing.T) { }, "ok/empty-sans": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -445,65 +454,90 @@ func TestX5C_AuthorizeSign(t *testing.T) { }, "ok/multi-sans": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{"127.0.0.1", "foo", "max@smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, } }, + "ok/cnf": func(t *testing.T) test { + p, err := generateX5C(nil) + require.NoError(t, err) + + x5c := make([]string, len(certs)) + for i, cert := range certs { + x5c[i] = base64.StdEncoding.EncodeToString(cert.Raw) + } + extraHeaders := map[string]any{"x5c": x5c} + extraClaims := map[string]any{ + "sans": []string{"127.0.0.1", "foo", "max@smallstep.com"}, + "cnf": map[string]any{"x5rt#S256": "fingerprint"}, + } + + tok, err := generateCustomToken("foo", p.GetName(), testAudiences.Sign[0], jwk, extraHeaders, extraClaims) + require.NoError(t, err) + return test{ + p: p, + token: tok, + sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, + fingerprint: "fingerprint", + } + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) ctx := NewContextWithMethod(context.Background(), SignIdentityMethod) if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { - if assert.NotNil(t, tc.err) { + if assert.NotNil(t, tc.err, err.Error()) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tc.code) + assert.Equal(t, tc.code, sc.StatusCode()) } - assert.HasPrefix(t, err.Error(), tc.err.Error()) + assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, 10, len(opts)) + assert.Len(t, opts, 11) for _, o := range opts { switch v := o.(type) { case *X5C: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, TypeX5C) - assert.Equals(t, v.Name, tc.p.GetName()) - assert.Equals(t, v.CredentialID, "") - assert.Len(t, 0, v.KeyValuePairs) + assert.Equal(t, TypeX5C, v.Type) + assert.Equal(t, tc.p.GetName(), v.Name) + assert.Equal(t, "", v.CredentialID) + assert.Len(t, v.KeyValuePairs, 0) case profileLimitDuration: - assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) + assert.Equal(t, tc.p.ctl.Claimer.DefaultTLSCertDuration(), v.def) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) - assert.FatalError(t, err) - assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) + require.NoError(t, err) + assert.Equal(t, claims.chains[0][0].NotAfter, v.notAfter) case commonNameValidator: - assert.Equals(t, string(v), "foo") + assert.Equal(t, "foo", string(v)) case defaultPublicKeyValidator: case *defaultSANsValidator: - assert.Equals(t, v.sans, tc.sans) - assert.Equals(t, MethodFromContext(v.ctx), SignIdentityMethod) + assert.Equal(t, tc.sans, v.sans) + assert.Equal(t, SignIdentityMethod, MethodFromContext(v.ctx)) case *validityValidator: - assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + assert.Equal(t, tc.p.ctl.Claimer.MinTLSCertDuration(), v.min) + assert.Equal(t, tc.p.ctl.Claimer.MaxTLSCertDuration(), v.max) case *x509NamePolicyValidator: - assert.Equals(t, nil, v.policyEngine) + assert.Equal(t, nil, v.policyEngine) case *WebhookController: - assert.Len(t, 0, v.webhooks) - assert.Equals(t, linkedca.Webhook_X509, v.certType) - assert.Len(t, 2, v.options) + assert.Len(t, v.webhooks, 0) + assert.Equal(t, linkedca.Webhook_X509, v.certType) + assert.Len(t, v.options, 2) + case csrFingerprintValidator: + assert.Equal(t, tc.fingerprint, string(v)) default: - assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) + require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -523,7 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -533,16 +567,16 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { }, "ok": func(t *testing.T) test { certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") - assert.FatalError(t, err) + require.NoError(t, err) jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") - assert.FatalError(t, err) + require.NoError(t, err) p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -556,9 +590,9 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tc.code) + assert.Equal(t, tc.code, sc.StatusCode()) } - assert.HasPrefix(t, err.Error(), tc.err.Error()) + assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) @@ -577,12 +611,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) { tests := map[string]func(*testing.T) test{ "fail/renew-disabled": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, code: http.StatusUnauthorized, @@ -591,7 +625,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) { }, "ok": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, } @@ -607,9 +641,9 @@ func TestX5C_AuthorizeRenew(t *testing.T) { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tc.code) + assert.Equal(t, tc.code, sc.StatusCode()) } - assert.HasPrefix(t, err.Error(), tc.err.Error()) + assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) @@ -620,28 +654,30 @@ func TestX5C_AuthorizeRenew(t *testing.T) { func TestX5C_AuthorizeSSHSign(t *testing.T) { x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") - assert.FatalError(t, err) + require.NoError(t, err) x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") - assert.FatalError(t, err) + require.NoError(t, err) _, fn := mockNow() defer fn() type test struct { - p *X5C - token string - claims *x5cPayload - code int - err error + p *X5C + token string + claims *x5cPayload + fingerprint string + count int + code int + err error } tests := map[string]func(*testing.T) test{ "fail/sshCA-disabled": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -651,7 +687,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { }, "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -661,11 +697,11 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { }, "fail/no-Step-claim": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -675,10 +711,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { }, "fail/no-SSH-subattribute-in-claims": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) id, err := randutil.ASCII(64) - assert.FatalError(t, err) + require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ @@ -693,7 +729,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { Step: &stepPayload{}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: tok, @@ -703,10 +739,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { }, "ok/with-claims": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) id, err := randutil.ASCII(64) - assert.FatalError(t, err) + require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ @@ -719,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{SSH: &SignSSHOptions{ - CertType: SSHHostCert, + CertType: SSHUserCert, KeyID: "foo", Principals: []string{"max", "mariano", "alan"}, ValidAfter: TimeDuration{d: 5 * time.Minute}, @@ -727,19 +763,20 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { }}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, claims: claims, token: tok, + count: 12, } }, "ok/without-claims": func(t *testing.T) test { p, err := generateX5C(nil) - assert.FatalError(t, err) + require.NoError(t, err) id, err := randutil.ASCII(64) - assert.FatalError(t, err) + require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ @@ -754,11 +791,47 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { Step: &stepPayload{SSH: &SignSSHOptions{}}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, claims: claims, token: tok, + count: 10, + } + }, + "ok/cnf": func(t *testing.T) test { + p, err := generateX5C(nil) + require.NoError(t, err) + + id, err := randutil.ASCII(64) + require.NoError(t, err) + now := time.Now() + claims := &x5cPayload{ + Claims: jose.Claims{ + ID: id, + Subject: "foo", + Issuer: p.GetName(), + IssuedAt: jose.NewNumericDate(now), + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + Audience: []string{testAudiences.SSHSign[0]}, + }, + Step: &stepPayload{SSH: &SignSSHOptions{ + CertType: SSHHostCert, + Principals: []string{"host.smallstep.com"}, + }}, + Confirmation: &cnfPayload{ + Fingerprint: "fingerprint", + }, + } + tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) + require.NoError(t, err) + return test{ + p: p, + claims: claims, + token: tok, + fingerprint: "fingerprint", + count: 10, } }, } @@ -769,9 +842,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tc.code) + assert.Equal(t, tc.code, sc.StatusCode()) } - assert.HasPrefix(t, err.Error(), tc.err.Error()) + assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { @@ -786,38 +859,34 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{} if firstValidator { - assert.Equals(t, SignSSHOptions(v), *tc.claims.Step.SSH) + assert.Equal(t, *tc.claims.Step.SSH, SignSSHOptions(v)) } else { - assert.Equals(t, SignSSHOptions(v), SignSSHOptions{KeyID: tc.claims.Subject}) + assert.Equal(t, SignSSHOptions{KeyID: tc.claims.Subject}, SignSSHOptions(v)) } firstValidator = false case sshCertValidAfterModifier: - assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix()) + assert.Equal(t, tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix(), int64(v)) case sshCertValidBeforeModifier: - assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix()) + assert.Equal(t, tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix(), int64(v)) case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) - assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) + assert.Equal(t, tc.p.ctl.Claimer, v.Claimer) + assert.Equal(t, x5cCerts[0].NotAfter, v.NotAfter) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) + assert.Equal(t, tc.p.ctl.Claimer, v.Claimer) case *sshNamePolicyValidator: - assert.Equals(t, nil, v.userPolicyEngine) - assert.Equals(t, nil, v.hostPolicyEngine) + assert.Nil(t, v.userPolicyEngine) + assert.Nil(t, v.hostPolicyEngine) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *WebhookController: - assert.Len(t, 0, v.webhooks) - assert.Equals(t, linkedca.Webhook_SSH, v.certType) - assert.Len(t, 2, v.options) + assert.Len(t, v.webhooks, 0) + assert.Equal(t, linkedca.Webhook_SSH, v.certType) + assert.Len(t, v.options, 2) default: - assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) + require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } - if tc.claims.Step.SSH.CertType != "" { - assert.Equals(t, tot, 12) - } else { - assert.Equals(t, tot, 10) - } + assert.Equal(t, tc.count, tot) } } } diff --git a/authority/ssh.go b/authority/ssh.go index 55f4f4a21..30e4bfc7b 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -154,12 +154,16 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { var ( - certOptions []sshutil.Option - mods []provisioner.SSHCertModifier - validators []provisioner.SSHCertValidator + certOptions []sshutil.Option + mods []provisioner.SSHCertModifier + validators []provisioner.SSHCertValidator + keyValidators []provisioner.SSHPublicKeyValidator ) - // Validate given options. + // Validate given key and options + if key == nil { + return nil, nil, errs.BadRequest("ssh public key cannot be nil") + } if err := opts.Validate(); err != nil { return nil, nil, err } @@ -183,6 +187,10 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi case provisioner.SSHCertModifier: mods = append(mods, o) + // validate the ssh public key + case provisioner.SSHPublicKeyValidator: + keyValidators = append(keyValidators, o) + // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) @@ -202,6 +210,16 @@ func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisi } } + // Validate public key + for _, v := range keyValidators { + if err := v.Valid(key); err != nil { + return nil, nil, errs.ApplyOptions( + errs.ForbiddenErr(err, err.Error()), + errs.WithKeyVal("signOptions", signOpts), + ) + } + } + // Simulated certificate request with request options. cr := sshutil.CertificateRequest{ Type: opts.CertType, diff --git a/authority/tls_test.go b/authority/tls_test.go index 4965fe46e..c7bd6f10d 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" @@ -224,15 +225,6 @@ func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { return hash[:], nil } -func assertHasPrefix(t *testing.T, s, p string) bool { - if strings.HasPrefix(s, p) { - return true - } - t.Helper() - t.Errorf("%q is not a prefix of %q", p, s) - return false -} - type basicConstraints struct { IsCA bool `asn1:"optional"` MaxPathLen int `asn1:"optional,default:-1"` @@ -249,6 +241,11 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error { return nil } +func assertHasPrefix(t *testing.T, s, p string) bool { + t.Helper() + return assert.True(t, strings.HasPrefix(s, p), "%q is not a prefix of %q", p, s) +} + func TestAuthority_SignWithContext(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) @@ -605,6 +602,43 @@ ZYtQ9Ot36qc= code: http.StatusForbidden, } }, + "fail with cnf": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + + auth := testAuthority(t) + auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + auth.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equal(t, crt.Subject.CommonName, "smallstep test") + assert.Equal(t, crt.DNSNames, []string{"test.smallstep.com"}) + return nil + }, + } + + // Create a token with cnf + tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{ + "sans": []string{"test.smallstep.com"}, + "cnf": map[string]any{"x5rt#S256": "bad-fingerprint"}, + }) + require.NoError(t, err) + + opts, err := auth.Authorize(ctx, tok) + require.NoError(t, err) + + return &signTest{ + auth: auth, + csr: csr, + extraOpts: opts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + err: errors.New(`certificate request fingerprint does not match "bad-fingerprint"`), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *signTest { csr := getCSR(t, priv) _a := testAuthority(t) @@ -852,6 +886,44 @@ ZYtQ9Ot36qc= extensionsCount: 6, } }, + "ok with cnf": func(t *testing.T) *signTest { + csr := getCSR(t, priv) + fingerprint, err := fingerprint.New(csr.Raw, crypto.SHA256, fingerprint.Base64RawURLFingerprint) + require.NoError(t, err) + + auth := testAuthority(t) + auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + auth.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equal(t, crt.Subject.CommonName, "smallstep test") + assert.Equal(t, crt.DNSNames, []string{"test.smallstep.com"}) + return nil + }, + } + + // Create a token with cnf + tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{ + "sans": []string{"test.smallstep.com"}, + "cnf": map[string]any{"x5rt#S256": fingerprint}, + }) + require.NoError(t, err) + + opts, err := auth.Authorize(ctx, tok) + require.NoError(t, err) + + return &signTest{ + auth: auth, + csr: csr, + extraOpts: opts, + signOpts: signOpts, + notBefore: signOpts.NotBefore.Time().Truncate(time.Second), + notAfter: signOpts.NotAfter.Time().Truncate(time.Second), + extensionsCount: 6, + } + }, } for name, genTestCase := range tests { diff --git a/ca/ca_test.go b/ca/ca_test.go index a8c173c4d..30a3fbbb7 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -625,7 +625,7 @@ func TestCARenew(t *testing.T) { cert, err := x509util.NewCertificate(cr) assert.FatalError(t, err) crt := cert.GetCertificate() - crt.NotBefore = time.Now() + crt.NotBefore = now crt.NotAfter = leafExpiry crt, err = x509util.CreateCertificate(crt, intermediateCert, pub, intermediateKey.(crypto.Signer)) assert.FatalError(t, err)