diff --git a/api/ssh.go b/api/ssh.go index e0e4e01cd..5cfe906da 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -6,8 +6,11 @@ import ( "encoding/base64" "encoding/json" "net/http" + "net/url" + "strings" "time" + "github.com/google/uuid" "github.com/pkg/errors" "golang.org/x/crypto/ssh" @@ -326,6 +329,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { // Enforce the same duration as ssh certificate. signOpts = append(signOpts, &identityModifier{ + Identity: getIdentityURI(cr), NotBefore: time.Unix(int64(cert.ValidAfter), 0), NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) @@ -498,14 +502,42 @@ func SSHBastion(w http.ResponseWriter, r *http.Request) { }) } -// identityModifier is a custom modifier used to force a fixed duration. +// identityModifier is a custom modifier used to force a fixed duration, and set +// the identity URI. type identityModifier struct { + Identity *url.URL NotBefore time.Time NotAfter time.Time } +// Enforce implements the enforcer interface and sets the validity bounds and +// the identity uri to the certificate. func (m *identityModifier) Enforce(cert *x509.Certificate) error { cert.NotBefore = m.NotBefore cert.NotAfter = m.NotAfter + if m.Identity != nil { + var identityURL = m.Identity.String() + for _, u := range cert.URIs { + if u.String() == identityURL { + return nil + } + } + cert.URIs = append(cert.URIs, m.Identity) + } + + return nil +} + +// getIdentityURI returns the first valid UUID URN from the given CSR. +func getIdentityURI(cr *x509.CertificateRequest) *url.URL { + for _, u := range cr.URIs { + s := u.String() + // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + if len(s) == 9+36 && strings.EqualFold(s[:9], "urn:uuid:") { + if _, err := uuid.Parse(s); err == nil { + return u + } + } + } return nil } diff --git a/api/ssh_test.go b/api/ssh_test.go index 2b90dc12e..7d917fa75 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -13,18 +13,20 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "reflect" "strings" "testing" "time" - "golang.org/x/crypto/ssh" - - "github.com/smallstep/assert" + "github.com/google/uuid" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" ) var ( @@ -123,9 +125,9 @@ func getSignedHostCertificate() (*ssh.Certificate, error) { func TestSSHCertificate_MarshalJSON(t *testing.T) { user, err := getSignedUserCertificate() - assert.FatalError(t, err) + require.NoError(t, err) host, err := getSignedHostCertificate() - assert.FatalError(t, err) + require.NoError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) @@ -161,9 +163,9 @@ func TestSSHCertificate_MarshalJSON(t *testing.T) { func TestSSHCertificate_UnmarshalJSON(t *testing.T) { user, err := getSignedUserCertificate() - assert.FatalError(t, err) + require.NoError(t, err) host, err := getSignedHostCertificate() - assert.FatalError(t, err) + require.NoError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal()) @@ -253,9 +255,9 @@ func TestSignSSHRequest_Validate(t *testing.T) { func Test_SSHSign(t *testing.T) { user, err := getSignedUserCertificate() - assert.FatalError(t, err) + require.NoError(t, err) host, err := getSignedHostCertificate() - assert.FatalError(t, err) + require.NoError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) @@ -264,24 +266,24 @@ func Test_SSHSign(t *testing.T) { PublicKey: user.Key.Marshal(), OTT: "ott", }) - assert.FatalError(t, err) + require.NoError(t, err) hostReq, err := json.Marshal(SSHSignRequest{ PublicKey: host.Key.Marshal(), OTT: "ott", }) - assert.FatalError(t, err) + require.NoError(t, err) userAddReq, err := json.Marshal(SSHSignRequest{ PublicKey: user.Key.Marshal(), OTT: "ott", AddUserPublicKey: user.Key.Marshal(), }) - assert.FatalError(t, err) + require.NoError(t, err) userIdentityReq, err := json.Marshal(SSHSignRequest{ PublicKey: user.Key.Marshal(), OTT: "ott", IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)}, }) - assert.FatalError(t, err) + require.NoError(t, err) identityCerts := []*x509.Certificate{ parseCertificate(certPEM), } @@ -355,11 +357,11 @@ func Test_SSHSign(t *testing.T) { func Test_SSHRoots(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) host, err := ssh.NewPublicKey(sshHostKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) tests := []struct { @@ -409,11 +411,11 @@ func Test_SSHRoots(t *testing.T) { func Test_SSHFederation(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) host, err := ssh.NewPublicKey(sshHostKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) tests := []struct { @@ -471,9 +473,9 @@ func Test_SSHConfig(t *testing.T) { {Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte("ecdsa-sha2-nistp256 AAAA...=")}, } userJSON, err := json.Marshal(userOutput) - assert.FatalError(t, err) + require.NoError(t, err) hostJSON, err := json.Marshal(hostOutput) - assert.FatalError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -574,7 +576,7 @@ func Test_SSHGetHosts(t *testing.T) { {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, } hostsJSON, err := json.Marshal(hosts) - assert.FatalError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -676,7 +678,7 @@ func Test_SSHBastion(t *testing.T) { func TestSSHPublicKey_MarshalJSON(t *testing.T) { key, err := ssh.NewPublicKey(sshUserKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) keyB64 := base64.StdEncoding.EncodeToString(key.Marshal()) tests := []struct { @@ -705,7 +707,7 @@ func TestSSHPublicKey_MarshalJSON(t *testing.T) { func TestSSHPublicKey_UnmarshalJSON(t *testing.T) { key, err := ssh.NewPublicKey(sshUserKey.Public()) - assert.FatalError(t, err) + require.NoError(t, err) keyB64 := base64.StdEncoding.EncodeToString(key.Marshal()) type args struct { @@ -736,3 +738,98 @@ func TestSSHPublicKey_UnmarshalJSON(t *testing.T) { }) } } + +func Test_identityModifier_Enforce(t *testing.T) { + now := time.Now() + type fields struct { + Identity *url.URL + NotBefore time.Time + NotAfter time.Time + } + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + want *x509.Certificate + assertion assert.ErrorAssertionFunc + }{ + {"ok", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, + args{&x509.Certificate{}}, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, + }, assert.NoError}, + {"ok exists", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, + args{&x509.Certificate{ + URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, + }}, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, + }, assert.NoError}, + {"ok append", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, + args{&x509.Certificate{ + URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"}}, + }}, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + URIs: []*url.URL{ + {Scheme: "urn", Opaque: "uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"}, + {Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, + }, + }, assert.NoError}, + {"ok no identity", fields{nil, now, now.Add(time.Hour)}, + args{&x509.Certificate{}}, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &identityModifier{ + Identity: tt.fields.Identity, + NotBefore: tt.fields.NotBefore, + NotAfter: tt.fields.NotAfter, + } + tt.assertion(t, m.Enforce(tt.args.cert)) + }) + } +} + +func Test_getIdentityURI(t *testing.T) { + id, err := uuid.Parse("54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1") + require.NoError(t, err) + u, err := url.Parse(id.URN()) + require.NoError(t, err) + + type args struct { + cr *x509.CertificateRequest + } + tests := []struct { + name string + args args + want *url.URL + }{ + {"ok", args{&x509.CertificateRequest{ + URIs: []*url.URL{u}, + }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, + {"ok multiple", args{&x509.CertificateRequest{ + URIs: []*url.URL{u, {Scheme: "urn", Opaque: "uuid:f0e74f3a-95fe-4cf6-98e3-68e55b69ba48"}}, + }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, + {"ok multiple with invalid", args{&x509.CertificateRequest{ + URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:f0e74f3a+95fe+4cf6+98e3+68e55b69ba48"}, u}, + }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, + {"ok missing", args{&x509.CertificateRequest{ + URIs: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, + }}, nil}, + {"ok empty", args{&x509.CertificateRequest{}}, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, getIdentityURI(tt.args.cr)) + }) + } +}