Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce AWS KMS types.MessageTypeRaw for AWS KMS signing operations #573

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions kms/awskms/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,48 @@ import (
"go.step.sm/crypto/pemutil"
)

// AWSOptions implements the crypto.SignerOpts interface, it provides a Raw
// boolean field to indicate to the AWS KMS operation that the MessageType is
// RAW.
//
// Example:
//
// // Sign a raw message with KMS
// client := kms.NewFromConfig(cfg)
// kmsSigner, err := awskms.NewSigner(client, "my-key-id")
// if err != nil {
// // handle error ...
// }
// raw := []byte("my raw message")
// sig, err := kmsSigner.Sign(rand.Reader, raw, &awskms.AWSOptions{
// Raw: true,
// Options: crypto.SHA256,
// })
// if err != nil {
// // handle error ...
// }
type AWSOptions struct {
// Raw specifies to the AWS KMS operation that MessageType is RAW.
Raw bool
Options crypto.SignerOpts
}

// HashFunc implements crypto.SignerOpts.
func (a *AWSOptions) HashFunc() crypto.Hash {
// The GoLang [crypto.SignerOpt] interfaces states that if the [HashFunc]
// returns 0, then it indicates to the [Sign] function that no hashing
// has occurred over the message.
// However, the AWS KMS Sign operation always requires that a
// SigningAlgorithm is specified.
// As such, the AWSOptions HashFunc() must return a valid (non-zero) Hash,
// such that the [getMessageTypeAndSigningAlgorithm] function can return a valid AWS KMS
// [types.SigningAlgorithmSpec]
return a.Options.HashFunc()
}

// compile time check that AWSOptions implements crypto.SignerOpts
var _ crypto.SignerOpts = (*AWSOptions)(nil)

// Signer implements a crypto.Signer using the AWS KMS.
type Signer struct {
client KeyManagementClient
Expand Down Expand Up @@ -63,7 +105,7 @@ func (s *Signer) Public() crypto.PublicKey {

// Sign signs digest with the private key stored in the AWS KMS.
func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
alg, err := getSigningAlgorithm(s.Public(), opts)
messageType, alg, err := getMessageTypeAndSigningAlgorithm(s.Public(), opts)
if err != nil {
return nil, err
}
Expand All @@ -72,7 +114,7 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt
KeyId: pointer(s.keyID),
SigningAlgorithm: alg,
Message: digest,
MessageType: types.MessageTypeDigest,
MessageType: messageType,
}

ctx, cancel := defaultContext()
Expand All @@ -86,41 +128,49 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt
return resp.Signature, nil
}

func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.SigningAlgorithmSpec, error) {
func getMessageTypeAndSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.MessageType, types.SigningAlgorithmSpec, error) {
messageType := types.MessageTypeDigest
if awsOpts, ok := opts.(*AWSOptions); ok {
if awsOpts.Raw {
messageType = types.MessageTypeRaw
}
opts = awsOpts.Options
}

switch key.(type) {
case *rsa.PublicKey:
_, isPSS := opts.(*rsa.PSSOptions)
switch h := opts.HashFunc(); h {
case crypto.SHA256:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha256, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha256, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
case crypto.SHA384:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha384, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha384, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
case crypto.SHA512:
if isPSS {
return types.SigningAlgorithmSpecRsassaPssSha512, nil
return messageType, types.SigningAlgorithmSpecRsassaPssSha512, nil
}
return types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
return messageType, types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
return messageType, "", errors.Errorf("unsupported hash function %v", h)
}
case *ecdsa.PublicKey:
switch h := opts.HashFunc(); h {
case crypto.SHA256:
return types.SigningAlgorithmSpecEcdsaSha256, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha256, nil
case crypto.SHA384:
return types.SigningAlgorithmSpecEcdsaSha384, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha384, nil
case crypto.SHA512:
return types.SigningAlgorithmSpecEcdsaSha512, nil
return messageType, types.SigningAlgorithmSpecEcdsaSha512, nil
default:
return "", errors.Errorf("unsupported hash function %v", h)
return messageType, "", errors.Errorf("unsupported hash function %v", h)
}
default:
return "", errors.Errorf("unsupported key type %T", key)
return messageType, "", errors.Errorf("unsupported key type %T", key)
}
}
59 changes: 37 additions & 22 deletions kms/awskms/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func TestSigner_Sign(t *testing.T) {
wantErr bool
}{
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"(raw) ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.SHA256}}, signature, false},
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), &AWSOptions{Raw: true, Options: crypto.MD5}}, nil, true},
{"(raw) fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
{"fail sign", fields{&MockClient{
sign: func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) {
Expand All @@ -152,39 +154,52 @@ func TestSigner_Sign(t *testing.T) {
}
}

func Test_getSigningAlgorithm(t *testing.T) {
func Test_getMessageTypeAndSigningAlgorithm(t *testing.T) {
type args struct {
key crypto.PublicKey
opts crypto.SignerOpts
}
tests := []struct {
name string
args args
want types.SigningAlgorithmSpec
wantErr bool
name string
args args
wantMessageType types.MessageType
wantAlgo types.SigningAlgorithmSpec
wantErr bool
}{
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_256", false},
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_384", false},
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, types.MessageTypeDigest, "RSASSA_PSS_SHA_512", false},
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, types.MessageTypeDigest, "ECDSA_SHA_256", false},
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, types.MessageTypeDigest, "ECDSA_SHA_384", false},
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, types.MessageTypeDigest, "ECDSA_SHA_512", false},
{"(raw)rsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_256", false},
{"(raw)rsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_384", false},
{"(raw)rsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "RSASSA_PKCS1_V1_5_SHA_512", false},
{"(raw)pssrsa+sha256", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_256", false},
{"(raw)pssrsa+sha384", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_384", false},
{"(raw)pssrsa+sha512", args{&rsa.PublicKey{}, &AWSOptions{Raw: true, Options: &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}}, types.MessageTypeRaw, "RSASSA_PSS_SHA_512", false},
{"(raw)P256", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA256}}, types.MessageTypeRaw, "ECDSA_SHA_256", false},
{"(raw)P384", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA384}}, types.MessageTypeRaw, "ECDSA_SHA_384", false},
{"(raw)P521", args{&ecdsa.PublicKey{}, &AWSOptions{Raw: true, Options: crypto.SHA512}}, types.MessageTypeRaw, "ECDSA_SHA_512", false},
{"fail type", args{[]byte("key"), crypto.SHA256}, types.MessageTypeDigest, "", true},
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true},
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, types.MessageTypeDigest, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
gotMessageType, gotAlgo, err := getMessageTypeAndSigningAlgorithm(tt.args.key, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("getMessageTypeAndSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
if gotMessageType != tt.wantMessageType {
t.Errorf("getMessageTypeAndSigningAlgorithm() (message type) = %v, want %v", gotMessageType, tt.wantMessageType)
}
if gotAlgo != tt.wantAlgo {
t.Errorf("getMessageTypeAndSigningAlgorithm() (algorithm) = %v, want %v", gotAlgo, tt.wantAlgo)
}
})
}
Expand Down