-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add ability to sign with vault transit secret engine
- Loading branch information
1 parent
319d50a
commit aecd110
Showing
7 changed files
with
490 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
package hashivault | ||
|
||
import ( | ||
"context" | ||
"crypto" | ||
"encoding/base64" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"os" | ||
"regexp" | ||
"strconv" | ||
|
||
vault "github.com/hashicorp/vault-client-go" | ||
"github.com/hashicorp/vault-client-go/schema" | ||
"github.com/in-toto/go-witness/cryptoutil" | ||
"github.com/in-toto/go-witness/signer/kms" | ||
) | ||
|
||
func init() { | ||
kms.AddProvider(ReferenceScheme, &clientOptions{}, func(ctx context.Context, ksp *kms.KMSSignerProvider) (cryptoutil.Signer, error) { | ||
return LoadSignerVerifier(ctx, ksp) | ||
}) | ||
} | ||
|
||
const ( | ||
ReferenceScheme = "hashivault://" | ||
providerName = "kms-hashivault" | ||
) | ||
|
||
var ( | ||
errReference = errors.New("kms specification should be in the format hashivault://<key>") | ||
referenceRegex = regexp.MustCompile(`^hashivault://(?P<path>\w(([\w-.]+)?\w)?)$`) | ||
) | ||
|
||
func ValidReference(ref string) error { | ||
if !referenceRegex.MatchString(ref) { | ||
return errReference | ||
} | ||
|
||
return nil | ||
} | ||
|
||
type client struct { | ||
client *vault.Client | ||
keyPath string | ||
transitSecretsEnginePath string | ||
keyVersion int32 | ||
} | ||
|
||
func newClient(opts *clientOptions) (*client, error) { | ||
vaultOpts := []vault.ClientOption{vault.WithEnvironment()} | ||
if len(opts.addr) > 0 { | ||
vaultOpts = append(vaultOpts, vault.WithAddress(opts.addr)) | ||
} | ||
|
||
vaultClient, err := vault.New(vaultOpts...) | ||
if err != nil { | ||
return nil, fmt.Errorf("could not create vault client: %w", err) | ||
} | ||
|
||
token := "" | ||
if len(opts.tokenFile) > 0 { | ||
tokenBytes, err := os.ReadFile(opts.tokenFile) | ||
if err != nil { | ||
return nil, fmt.Errorf("could not read vault token file: %w", err) | ||
} | ||
|
||
token = string(tokenBytes) | ||
} | ||
|
||
if len(token) > 0 { | ||
if err := vaultClient.SetToken(token); err != nil { | ||
return nil, fmt.Errorf("invalid vault token") | ||
} | ||
} | ||
|
||
return &client{ | ||
client: vaultClient, | ||
keyPath: opts.keyPath, | ||
transitSecretsEnginePath: opts.transitSecretEnginePath, | ||
keyVersion: opts.keyVersion, | ||
}, nil | ||
} | ||
|
||
func (c *client) sign(ctx context.Context, digest []byte, hashFunc crypto.Hash) ([]byte, error) { | ||
hashStr, ok := supportedHashesToString[hashFunc] | ||
if !ok { | ||
return nil, fmt.Errorf("unsupported hash algorithm: %v", hashFunc.String()) | ||
} | ||
|
||
resp, err := c.client.Secrets.TransitSignWithAlgorithm( | ||
ctx, | ||
c.keyPath, | ||
hashStr, | ||
schema.TransitSignWithAlgorithmRequest{ | ||
SignatureAlgorithm: "pkcs1v15", | ||
HashAlgorithm: hashStr, | ||
KeyVersion: c.keyVersion, | ||
Prehashed: true, | ||
Input: base64.StdEncoding.Strict().EncodeToString(digest), | ||
}, | ||
c.requestOptions()..., | ||
) | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("could not sign: %w", err) | ||
} | ||
|
||
signature, ok := resp.Data["signature"] | ||
if !ok { | ||
return nil, fmt.Errorf("no signature in response: %w", err) | ||
} | ||
|
||
sigStr, ok := signature.(string) | ||
if !ok { | ||
return nil, fmt.Errorf("invalid signature in response") | ||
} | ||
|
||
return []byte(sigStr), nil | ||
} | ||
|
||
func (c *client) verify(ctx context.Context, r io.Reader, sig []byte, hashFunc crypto.Hash) error { | ||
hashStr, ok := supportedHashesToString[hashFunc] | ||
if !ok { | ||
return fmt.Errorf("unsupported hash algorithm: %v", hashFunc.String()) | ||
} | ||
|
||
digest, err := cryptoutil.Digest(r, hashFunc) | ||
if err != nil { | ||
return fmt.Errorf("could not calculate digest: %w", err) | ||
} | ||
|
||
resp, err := c.client.Secrets.TransitVerifyWithAlgorithm( | ||
ctx, | ||
c.keyPath, | ||
hashStr, | ||
schema.TransitVerifyWithAlgorithmRequest{ | ||
SignatureAlgorithm: "pkcs1v15", | ||
HashAlgorithm: hashStr, | ||
Prehashed: true, | ||
Signature: string(sig), | ||
Input: base64.StdEncoding.Strict().EncodeToString(digest), | ||
}, | ||
c.requestOptions()..., | ||
) | ||
|
||
if err != nil { | ||
return fmt.Errorf("could not verify: %w", err) | ||
} | ||
|
||
valid, ok := resp.Data["valid"] | ||
if !ok { | ||
return fmt.Errorf("invalid response") | ||
} | ||
|
||
validBool, ok := valid.(bool) | ||
if !ok { | ||
return fmt.Errorf("expected valid to be bool but is %T", valid) | ||
} | ||
|
||
if !validBool { | ||
return fmt.Errorf("failed verification") | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *client) getPublicKeyBytes(ctx context.Context) ([]byte, error) { | ||
resp, err := c.client.Secrets.TransitReadKey( | ||
ctx, | ||
c.keyPath, | ||
c.requestOptions()..., | ||
) | ||
|
||
if err != nil { | ||
return nil, fmt.Errorf("could not read key: %w", err) | ||
} | ||
|
||
keyVersion := strconv.FormatInt(int64(c.keyVersion), 10) | ||
if keyVersion == "0" { | ||
latestVersion, ok := resp.Data["lastest_version"] | ||
if !ok { | ||
return nil, fmt.Errorf("latest key version not in response") | ||
} | ||
|
||
latestVersionNum, ok := latestVersion.(json.Number) | ||
if !ok { | ||
return nil, fmt.Errorf("latest version not a number") | ||
} | ||
|
||
keyVersion = latestVersionNum.String() | ||
} | ||
|
||
keys, ok := resp.Data["keys"] | ||
if !ok { | ||
return nil, fmt.Errorf("no keys in response") | ||
} | ||
|
||
keysMap, ok := keys.(map[string]interface{}) | ||
if !ok { | ||
return nil, fmt.Errorf("unexpected keys value in response") | ||
} | ||
|
||
keyInfo, ok := keysMap[keyVersion] | ||
if !ok { | ||
return nil, fmt.Errorf("could not find key with version %v", keyVersion) | ||
} | ||
|
||
keyMap, ok := keyInfo.(map[string]interface{}) | ||
if !ok { | ||
return nil, fmt.Errorf("unexpected key data format in response") | ||
} | ||
|
||
publicKey, ok := keyMap["public_key"] | ||
if !ok { | ||
return nil, fmt.Errorf("public key not in key data") | ||
} | ||
|
||
publicKeyStr, ok := publicKey.(string) | ||
if !ok { | ||
return nil, fmt.Errorf("unexpected public key data in response") | ||
} | ||
|
||
return []byte(publicKeyStr), nil | ||
} | ||
|
||
func (c *client) requestOptions() []vault.RequestOption { | ||
opts := []vault.RequestOption{} | ||
if len(c.transitSecretsEnginePath) > 0 { | ||
opts = append(opts, vault.WithMountPath(c.transitSecretsEnginePath)) | ||
} | ||
|
||
return opts | ||
} |
Oops, something went wrong.