Skip to content

Commit

Permalink
Merge pull request #692 from smallstep/max/context
Browse files Browse the repository at this point in the history
Context management
  • Loading branch information
dopey authored Nov 17, 2021
2 parents 440616c + df28436 commit de2ce5c
Show file tree
Hide file tree
Showing 18 changed files with 263 additions and 133 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
run: V=1 make ci
-
name: Codecov
if: matrix.go == '1.17'
uses: codecov/[email protected]
with:
file: ./coverage.out # optional
Expand Down
6 changes: 3 additions & 3 deletions authority/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/step"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb"
)
Expand Down Expand Up @@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return ""
}

stepPath := filepath.ToSlash(config.StepPath())
stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") {
stepPath += "/"
}
Expand All @@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err)
}
if ok {
b, err := os.ReadFile(config.StepAbs(fn))
b, err := os.ReadFile(step.Abs(fn))
if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn))
}
Expand Down
10 changes: 7 additions & 3 deletions authority/provisioners.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose"
"go.step.sm/linkedca"
Expand Down Expand Up @@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil
}

// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
Expand Down Expand Up @@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil
}

// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error {
if c == nil {
return nil
Expand All @@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil
}

// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error {
var (
err error
Expand Down Expand Up @@ -523,7 +527,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile)
filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template")
}
Expand All @@ -539,7 +543,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile)
filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template")
}
Expand Down
9 changes: 9 additions & 0 deletions authority/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
if err != nil {
return nil, err
}

// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}

output = append(output, o)
}
return output, nil
Expand Down
29 changes: 29 additions & 0 deletions authority/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
}

tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}

userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}

tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
Expand Down Expand Up @@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},
Expand Down
1 change: 1 addition & 0 deletions authority/testdata/templates/step_includes.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}
17 changes: 14 additions & 3 deletions ca/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
Expand Down Expand Up @@ -225,7 +225,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return tr, nil
}

// WithTransport adds a custom transport to the Client. It will fail if a
// WithTransport adds a custom transport to the Client. It will fail if a
// previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error {
Expand All @@ -237,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
}
}

// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}

// WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption {
Expand Down Expand Up @@ -1294,7 +1305,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable.
func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
return filepath.Join(step.Path(), "certs", "root_ca.crt")
}

func readJSON(r io.ReadCloser, v interface{}) error {
Expand Down
13 changes: 7 additions & 6 deletions ca/identity/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json
func LoadClient() (*Client, error) {
b, err := os.ReadFile(DefaultsFile)
defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile)
return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
}

var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile)
return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
}
if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile)
return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
}
caURL, err := url.Parse(defaults.CaURL)
if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile)
return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
}
if caURL.Scheme == "" {
caURL.Scheme = "https"
Expand All @@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err
}
if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile)
return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
}
if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)
Expand Down
42 changes: 24 additions & 18 deletions ca/identity/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import (
"testing"
)

func returnInput(val string) func() string {
return func() string {
return val
}
}

func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile
Expand All @@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile
}()

IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")

client, err := LoadClient()
if err != nil {
Expand Down Expand Up @@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool
}{
{"ok", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false},
{"fail identity", func() {
IdentityFile = "testdata/config/missing.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
{"fail identity", func() {
IdentityFile = "testdata/config/fail.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
{"fail defaults", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/missing.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true},
{"fail defaults", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/fail.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true},
{"fail ca", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/badca.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true},
{"fail root", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/badroot.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true},
{"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
}
for _, tt := range tests {
Expand Down
35 changes: 18 additions & 17 deletions ca/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil"
)

Expand All @@ -38,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute

// IdentityFile contains the location of the identity file.
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json")
var (
identityDir = step.IdentityPath
configDir = step.ConfigPath

// IdentityFile contains a pointer to a function that outputs the location of
// the identity file.
IdentityFile = step.IdentityFile

// DefaultsFile contains the location of the defaults file.
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json")
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)

// Identity represents the identity file that can be used to authenticate with
// the CA.
Expand Down Expand Up @@ -73,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {

// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile)
return LoadIdentity(IdentityFile())
}

// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)

// WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil {
if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory")
}

identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory")
}
Expand Down Expand Up @@ -126,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil {
return errors.Wrap(err, "error writing identity json")
}
if err := os.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil {
if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}

Expand All @@ -135,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er

// WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt")
filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain)
}

Expand Down Expand Up @@ -318,7 +319,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate")
}
}
certFilename := filepath.Join(identityDir, "identity.crt")
certFilename := filepath.Join(identityDir(), "identity.crt")
if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
Expand Down
Loading

0 comments on commit de2ce5c

Please sign in to comment.