Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create identity uri on any provisioner #1922

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion api/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"encoding/base64"
"encoding/json"
"net/http"
"net/url"
"strings"
"time"

"github.com/google/uuid"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"

Expand Down Expand Up @@ -326,6 +329,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {

// Enforce the same duration as ssh certificate.
signOpts = append(signOpts, &identityModifier{
Identity: getIdentityURI(cr),
NotBefore: time.Unix(int64(cert.ValidAfter), 0),
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
})
Expand Down Expand Up @@ -498,14 +502,42 @@ func SSHBastion(w http.ResponseWriter, r *http.Request) {
})
}

// identityModifier is a custom modifier used to force a fixed duration.
// identityModifier is a custom modifier used to force a fixed duration. and set
maraino marked this conversation as resolved.
Show resolved Hide resolved
// the identity URI.
type identityModifier struct {
Identity *url.URL
NotBefore time.Time
NotAfter time.Time
}

// Enforce implements the enforcer interface and sets the validity bounds and
// the identity uri to the certificate.
func (m *identityModifier) Enforce(cert *x509.Certificate) error {
cert.NotBefore = m.NotBefore
cert.NotAfter = m.NotAfter
if m.Identity != nil {
var identityURL = m.Identity.String()
for _, u := range cert.URIs {
if u.String() == identityURL {
return nil
}
}
cert.URIs = append(cert.URIs, m.Identity)
}

return nil
}

// getIdentityURI returns the first valid UUID URN from the given CSR.
func getIdentityURI(cr *x509.CertificateRequest) *url.URL {
for _, u := range cr.URIs {
s := u.String()
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if len(s) == 9+36 && strings.EqualFold(s[:9], "urn:uuid:") {
if _, err := uuid.Parse(s); err == nil {
return u
}
}
}
return nil
}
141 changes: 119 additions & 22 deletions api/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"

"golang.org/x/crypto/ssh"

"github.com/smallstep/assert"
"github.com/google/uuid"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

var (
Expand Down Expand Up @@ -123,9 +125,9 @@ func getSignedHostCertificate() (*ssh.Certificate, error) {

func TestSSHCertificate_MarshalJSON(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
require.NoError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
require.NoError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

Expand Down Expand Up @@ -161,9 +163,9 @@ func TestSSHCertificate_MarshalJSON(t *testing.T) {

func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
require.NoError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
require.NoError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())
Expand Down Expand Up @@ -253,9 +255,9 @@ func TestSignSSHRequest_Validate(t *testing.T) {

func Test_SSHSign(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
require.NoError(t, err)
host, err := getSignedHostCertificate()
assert.FatalError(t, err)
require.NoError(t, err)

userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
Expand All @@ -264,24 +266,24 @@ func Test_SSHSign(t *testing.T) {
PublicKey: user.Key.Marshal(),
OTT: "ott",
})
assert.FatalError(t, err)
require.NoError(t, err)
hostReq, err := json.Marshal(SSHSignRequest{
PublicKey: host.Key.Marshal(),
OTT: "ott",
})
assert.FatalError(t, err)
require.NoError(t, err)
userAddReq, err := json.Marshal(SSHSignRequest{
PublicKey: user.Key.Marshal(),
OTT: "ott",
AddUserPublicKey: user.Key.Marshal(),
})
assert.FatalError(t, err)
require.NoError(t, err)
userIdentityReq, err := json.Marshal(SSHSignRequest{
PublicKey: user.Key.Marshal(),
OTT: "ott",
IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)},
})
assert.FatalError(t, err)
require.NoError(t, err)
identityCerts := []*x509.Certificate{
parseCertificate(certPEM),
}
Expand Down Expand Up @@ -355,11 +357,11 @@ func Test_SSHSign(t *testing.T) {

func Test_SSHRoots(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

host, err := ssh.NewPublicKey(sshHostKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

tests := []struct {
Expand Down Expand Up @@ -409,11 +411,11 @@ func Test_SSHRoots(t *testing.T) {

func Test_SSHFederation(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

host, err := ssh.NewPublicKey(sshHostKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

tests := []struct {
Expand Down Expand Up @@ -471,9 +473,9 @@ func Test_SSHConfig(t *testing.T) {
{Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte("ecdsa-sha2-nistp256 AAAA...=")},
}
userJSON, err := json.Marshal(userOutput)
assert.FatalError(t, err)
require.NoError(t, err)
hostJSON, err := json.Marshal(hostOutput)
assert.FatalError(t, err)
require.NoError(t, err)

tests := []struct {
name string
Expand Down Expand Up @@ -574,7 +576,7 @@ func Test_SSHGetHosts(t *testing.T) {
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
}
hostsJSON, err := json.Marshal(hosts)
assert.FatalError(t, err)
require.NoError(t, err)

tests := []struct {
name string
Expand Down Expand Up @@ -676,7 +678,7 @@ func Test_SSHBastion(t *testing.T) {

func TestSSHPublicKey_MarshalJSON(t *testing.T) {
key, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

tests := []struct {
Expand Down Expand Up @@ -705,7 +707,7 @@ func TestSSHPublicKey_MarshalJSON(t *testing.T) {

func TestSSHPublicKey_UnmarshalJSON(t *testing.T) {
key, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
require.NoError(t, err)
keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

type args struct {
Expand Down Expand Up @@ -736,3 +738,98 @@ func TestSSHPublicKey_UnmarshalJSON(t *testing.T) {
})
}
}

func Test_identityModifier_Enforce(t *testing.T) {
now := time.Now()
type fields struct {
Identity *url.URL
NotBefore time.Time
NotAfter time.Time
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
want *x509.Certificate
assertion assert.ErrorAssertionFunc
}{
{"ok", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)},
args{&x509.Certificate{}}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}},
}, assert.NoError},
{"ok exists", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)},
args{&x509.Certificate{
URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}},
}}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}},
}, assert.NoError},
{"ok append", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)},
args{&x509.Certificate{
URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"}},
}}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
URIs: []*url.URL{
{Scheme: "urn", Opaque: "uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"},
{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"},
},
}, assert.NoError},
{"ok no identity", fields{nil, now, now.Add(time.Hour)},
args{&x509.Certificate{}}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}, assert.NoError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &identityModifier{
Identity: tt.fields.Identity,
NotBefore: tt.fields.NotBefore,
NotAfter: tt.fields.NotAfter,
}
tt.assertion(t, m.Enforce(tt.args.cert))
})
}
}

func Test_getIdentityURI(t *testing.T) {
id, err := uuid.Parse("54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1")
require.NoError(t, err)
u, err := url.Parse(id.URN())
require.NoError(t, err)

type args struct {
cr *x509.CertificateRequest
}
tests := []struct {
name string
args args
want *url.URL
}{
{"ok", args{&x509.CertificateRequest{
URIs: []*url.URL{u},
}}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}},
{"ok multiple", args{&x509.CertificateRequest{
URIs: []*url.URL{u, {Scheme: "urn", Opaque: "uuid:f0e74f3a-95fe-4cf6-98e3-68e55b69ba48"}},
}}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}},
{"ok multiple with invalid", args{&x509.CertificateRequest{
URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:f0e74f3a+95fe+4cf6+98e3+68e55b69ba48"}, u},
}}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}},
{"ok missing", args{&x509.CertificateRequest{
URIs: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}},
}}, nil},
{"ok empty", args{&x509.CertificateRequest{}}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, getIdentityURI(tt.args.cr))
})
}
}
Loading