Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check cnf claim with CSR fingerprint #1660

Merged
merged 8 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion authority/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,39 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose
return jose.Signed(sig).Claims(claims).CompactSerialize()
}

func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)

for k, v := range extraHeaders {
so.WithHeader(jose.HeaderKey(k), v)
}

sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}

id, err := randutil.ASCII(64)
if err != nil {
return "", err
}

iat := time.Now()
claims := jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
}

return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize()
}

func TestAuthority_authorizeToken(t *testing.T) {
a := testAuthority(t)

Expand Down Expand Up @@ -491,7 +524,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, 10, len(got)) // number of provisioner.SignOptions returned
assert.Equals(t, 11, len(got)) // number of provisioner.SignOptions returned
}
}
})
Expand Down
21 changes: 19 additions & 2 deletions authority/provisioner/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ import (
// jwtPayload extends jwt.Claims with step attributes.
type jwtPayload struct {
jose.Claims
SANs []string `json:"sans,omitempty"`
Step *stepPayload `json:"step,omitempty"`
SANs []string `json:"sans,omitempty"`
Step *stepPayload `json:"step,omitempty"`
Confirmation *cnfPayload `json:"cnf,omitempty"`
}

type stepPayload struct {
SSH *SignSSHOptions `json:"ssh,omitempty"`
RA *RAInfo `json:"ra,omitempty"`
}

type cnfPayload struct {
Kid string `json:"kid,omitempty"`
}

// JWK is the default provisioner, an entity that can sign tokens necessary for
// signature requests.
type JWK struct {
Expand Down Expand Up @@ -183,13 +188,20 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
}

// Check the fingerprint of the certificate request if given.
var fingerprint string
if claims.Confirmation != nil {
fingerprint = claims.Confirmation.Kid
}

return []SignOption{
self,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
fingerprintValidator(fingerprint),
commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)),
defaultPublicKeyValidator{},
newDefaultSANsValidator(ctx, claims.SANs),
Expand Down Expand Up @@ -229,6 +241,11 @@ func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
sshCertOptionsValidator(SignSSHOptions{KeyID: claims.Subject}),
}

// Check the fingerprint of the certificate request if given.
if claims.Confirmation != nil && claims.Confirmation.Kid != "" {
signOptions = append(signOptions, sshFingerprintValidator(claims.Confirmation.Kid))
}

// Default template attributes.
certType := sshutil.UserCert
keyID := claims.Subject
Expand Down
79 changes: 60 additions & 19 deletions authority/provisioner/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
"testing"
"time"

"go.step.sm/crypto/fingerprint"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"

"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
Expand Down Expand Up @@ -247,19 +249,23 @@ func TestJWK_AuthorizeSign(t *testing.T) {
t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "[email protected]", []string{}, time.Now(), key1)
assert.FatalError(t, err)

t3, err := generateCustomToken("subject", p1.Name, testAudiences.Sign[0], key1, nil, map[string]any{"cnf": map[string]any{"kid": "fingerprint"}})
assert.FatalError(t, err)

// invalid signature
failSig := t1[0 : len(t1)-2]

type args struct {
token string
}
tests := []struct {
name string
prov *JWK
args args
code int
err error
sans []string
name string
prov *JWK
args args
code int
err error
sans []string
fingerprint string
}{
{
name: "fail-signature",
Expand All @@ -284,6 +290,15 @@ func TestJWK_AuthorizeSign(t *testing.T) {
err: nil,
sans: []string{"subject"},
},
{
name: "ok-cnf",
prov: p1,
args: args{t3},
code: http.StatusOK,
err: nil,
sans: []string{"subject"},
fingerprint: "fingerprint",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -297,7 +312,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
} else {
if assert.NotNil(t, got) {
assert.Equals(t, 10, len(got))
assert.Equals(t, 11, len(got))
for _, o := range got {
switch v := o.(type) {
case *JWK:
Expand All @@ -321,6 +336,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
case *WebhookController:
case fingerprintValidator:
assert.Equals(t, tt.fingerprint, string(v))
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
Expand Down Expand Up @@ -393,17 +410,6 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)

iss, aud := p1.Name, testAudiences.SSHSign[0]

t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err)

t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
assert.FatalError(t, err)

// invalid signature
failSig := t1[0 : len(t1)-2]

key, err := generateJSONWebKey()
assert.FatalError(t, err)

Expand All @@ -417,6 +423,39 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)

// Calculate fingerprint
sshPub, err := ssh.NewPublicKey(pub)
assert.FatalError(t, err)
fp, err := fingerprint.New(sshPub.Marshal(), crypto.SHA256, fingerprint.Base64RawURLFingerprint)
assert.FatalError(t, err)

iss, aud := p1.Name, testAudiences.SSHSign[0]

t1, err := generateSimpleSSHUserToken(iss, aud, jwk)
assert.FatalError(t, err)

t2, err := generateSimpleSSHHostToken(iss, aud, jwk)
assert.FatalError(t, err)

t3, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{
"step": map[string]any{
"ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}},
},
"cnf": map[string]any{"kid": fp},
})
assert.FatalError(t, err)

t4, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{
"step": map[string]any{
"ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}},
},
"cnf": map[string]any{"kid": "bad-fingerprint"},
})
assert.FatalError(t, err)

// invalid signature
failSig := t1[0 : len(t1)-2]

userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{
Expand Down Expand Up @@ -451,9 +490,11 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
{"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"host-cnf", p1, args{t3, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false},
{"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false},
{"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
{"rail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
{"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
{"fail-cnf", p1, args{t4, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusUnauthorized, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
21 changes: 21 additions & 0 deletions authority/provisioner/sign_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"crypto/x509"
"encoding/base64"
"encoding/json"
"net"
"net/http"
Expand Down Expand Up @@ -492,3 +495,21 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append(cert.ExtraExtensions, ext)
return nil
}

// fingerprintValidator is a CertificateRequestValidator that checks the
// fingerprint of the certificate with the provided one.
type fingerprintValidator string
hslatman marked this conversation as resolved.
Show resolved Hide resolved

func (s fingerprintValidator) Valid(cr *x509.CertificateRequest) error {
if s != "" {
expected, err := base64.RawURLEncoding.DecodeString(string(s))
if err != nil {
return errs.ForbiddenErr(err, "error decoding fingerprint")
}
sum := sha256.Sum256(cr.Raw)
if subtle.ConstantTimeCompare(expected, sum[:]) != 1 {
return errs.Forbidden("certificate request fingerprint does not match %q", s)
}
}
return nil
}
28 changes: 28 additions & 0 deletions authority/provisioner/sign_ssh_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package provisioner

import (
"crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -44,6 +47,13 @@ type SSHCertOptionsValidator interface {
Valid(got SignSSHOptions) error
}

// SSHPublicKeyValidator is the interface used to validate the public key of an
// SSH certificate.
type SSHPublicKeyValidator interface {
SignOption
Valid(got ssh.PublicKey) error
}

// SignSSHOptions contains the options that can be passed to the SignSSH method.
type SignSSHOptions struct {
CertType string `json:"certType"`
Expand Down Expand Up @@ -419,6 +429,24 @@ func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions)
}
}

// sshFingerprintValidator is a SSHPublicKeyValidator that checks the
// fingerprint of the public key with the provided one.
type sshFingerprintValidator string

func (s sshFingerprintValidator) Valid(key ssh.PublicKey) error {
if s != "" {
maraino marked this conversation as resolved.
Show resolved Hide resolved
expected, err := base64.RawURLEncoding.DecodeString(string(s))
if err != nil {
return errs.ForbiddenErr(err, "error decoding fingerprint")
}
sum := sha256.Sum256(key.Marshal())
if subtle.ConstantTimeCompare(expected, sum[:]) != 1 {
return errs.Forbidden("ssh public key fingerprint does not match %q", s)
}
}
return nil
}

// sshCertTypeUInt32
func sshCertTypeUInt32(ct string) uint32 {
switch ct {
Expand Down
9 changes: 9 additions & 0 deletions authority/provisioner/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si
var mods []SSHCertModifier
var certOptions []sshutil.Option
var validators []SSHCertValidator
var keyValidators []SSHPublicKeyValidator

for _, op := range signOpts {
switch o := op.(type) {
Expand All @@ -71,11 +72,19 @@ func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []Si
}
// call webhooks
case *WebhookController:
case sshFingerprintValidator:
keyValidators = append(keyValidators, o)
default:
return nil, fmt.Errorf("signSSH: invalid extra option type %T", o)
}
}

for _, v := range keyValidators {
if err := v.Valid(pub); err != nil {
return nil, err
}
}

// Simulated certificate request with request options.
cr := sshutil.CertificateRequest{
Type: opts.CertType,
Expand Down
31 changes: 31 additions & 0 deletions authority/provisioner/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,37 @@ func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jw
return jose.Signed(sig).Claims(claims).CompactSerialize()
}

func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID)

for k, v := range extraHeaders {
so.WithHeader(jose.HeaderKey(k), v)
}

sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so)
if err != nil {
return "", err
}

id, err := randutil.ASCII(64)
if err != nil {
return "", err
}
iat := time.Now()
claims := jose.Claims{
ID: id,
Subject: sub,
Issuer: iss,
IssuedAt: jose.NewNumericDate(iat),
NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud},
}
return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize()
}

func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions)
so.WithType("JWT")
Expand Down
Loading
Loading