diff --git a/go.mod b/go.mod index 82b1b91..2e45115 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ go 1.18 require ( github.com/aws/aws-sdk-go v1.43.9 + github.com/aws/aws-sdk-go-v2 v1.17.3 github.com/aws/aws-sdk-go-v2/service/kms v1.20.0 github.com/btcsuite/btcd v0.22.1 github.com/golang/mock v1.6.0 @@ -29,7 +30,6 @@ require ( require ( github.com/VictoriaMetrics/fastcache v1.5.7 // indirect - github.com/aws/aws-sdk-go-v2 v1.17.3 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect github.com/aws/smithy-go v1.13.5 // indirect diff --git a/pkg/aws/opts.go b/pkg/aws/opts.go index 3d46b6d..6db35eb 100644 --- a/pkg/aws/opts.go +++ b/pkg/aws/opts.go @@ -12,6 +12,7 @@ import ( type opts struct { keyAliasPrefix string + awsClient awsClient } // NewOpts create new opts populated with environment variable. @@ -34,3 +35,8 @@ type Opts func(opts *opts) func WithKeyAliasPrefix(prefix string) Opts { return func(opts *opts) { opts.keyAliasPrefix = prefix } } + +// WithAWSClient sets custom AWS client. +func WithAWSClient(client awsClient) Opts { + return func(opts *opts) { opts.awsClient = client } +} diff --git a/pkg/aws/service.go b/pkg/aws/service.go index 6210541..b088786 100644 --- a/pkg/aws/service.go +++ b/pkg/aws/service.go @@ -4,11 +4,14 @@ Copyright SecureKey Technologies Inc. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 */ +//go:generate mockgen -destination service_mocks.go -package aws -source=service.go + package aws import ( "context" "crypto/elliptic" + "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/asn1" @@ -26,7 +29,7 @@ import ( arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" ) -type awsClient interface { //nolint:dupl +type awsClient interface { Sign(ctx context.Context, params *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) @@ -37,11 +40,17 @@ type awsClient interface { //nolint:dupl optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) + Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error) + Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) } type metricsProvider interface { SignCount() + EncryptCount() + DecryptCount() SignTime(value time.Duration) + EncryptTime(value time.Duration) + DecryptTime(value time.Duration) ExportPublicKeyCount() ExportPublicKeyTime(value time.Duration) VerifyCount() @@ -58,6 +67,8 @@ type Service struct { client awsClient metrics metricsProvider healthCheckKeyID string + encryptionAlgo types.EncryptionAlgorithmSpec + nonceLength int } const ( @@ -79,23 +90,108 @@ var keySpecToCurve = map[types.KeySpec]elliptic.Curve{ types.KeySpecEccSecgP256k1: btcec.S256(), } +const ( + defaultNonceLength = 16 +) + // New return aws service. -func New(awsConfig *aws.Config, metrics metricsProvider, - healthCheckKeyID string, opts ...Opts) *Service { +func New( + awsConfig *aws.Config, + metrics metricsProvider, + healthCheckKeyID string, + opts ...Opts, +) *Service { options := newOpts() for _, opt := range opts { opt(options) } + client := options.awsClient + if client == nil { + client = kms.NewFromConfig(*awsConfig) + } + return &Service{ options: options, - client: kms.NewFromConfig(*awsConfig), + client: client, metrics: metrics, healthCheckKeyID: healthCheckKeyID, + encryptionAlgo: types.EncryptionAlgorithmSpecRsaesOaepSha256, + nonceLength: defaultNonceLength, } } +// Decrypt data. +func (s *Service) Decrypt(_, aad, _ []byte, kh interface{}) ([]byte, error) { + startTime := time.Now() + + defer func() { + if s.metrics != nil { + s.metrics.DecryptTime(time.Since(startTime)) + } + }() + + if s.metrics != nil { + s.metrics.DecryptCount() + } + + keyID, err := s.getKeyID(kh.(string)) + if err != nil { + return nil, err + } + + input := &kms.DecryptInput{ + CiphertextBlob: aad, + EncryptionAlgorithm: s.encryptionAlgo, + KeyId: aws.String(keyID), + } + + resp, err := s.client.Decrypt(context.Background(), input) + if err != nil { + return nil, err + } + + return resp.Plaintext, nil +} + +// Encrypt data. +func (s *Service) Encrypt( + msg []byte, + _ []byte, + kh interface{}, +) ([]byte, []byte, error) { + startTime := time.Now() + + defer func() { + if s.metrics != nil { + s.metrics.EncryptTime(time.Since(startTime)) + } + }() + + if s.metrics != nil { + s.metrics.EncryptCount() + } + + keyID, err := s.getKeyID(kh.(string)) + if err != nil { + return nil, nil, err + } + + input := &kms.EncryptInput{ + KeyId: aws.String(keyID), + Plaintext: msg, + EncryptionAlgorithm: s.encryptionAlgo, + } + + resp, err := s.client.Encrypt(context.Background(), input) + if err != nil { + return nil, nil, err + } + + return resp.CiphertextBlob, generateNonce(s.nonceLength), nil +} + // Sign data. func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: funlen startTime := time.Now() @@ -310,6 +406,13 @@ func (s *Service) getKeyID(keyURI string) (string, error) { return r[4], nil } +func generateNonce(length int) []byte { + key := make([]byte, length) + _, _ = rand.Read(key) //nolint: errcheck + + return key +} + func hashMessage(message []byte, algorithm types.SigningAlgorithmSpec) ([]byte, error) { var digest hash.Hash diff --git a/pkg/aws/service_mocks.go b/pkg/aws/service_mocks.go new file mode 100644 index 0000000..ac88063 --- /dev/null +++ b/pkg/aws/service_mocks.go @@ -0,0 +1,340 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: service.go + +// Package aws is a generated GoMock package. +package aws + +import ( + context "context" + reflect "reflect" + time "time" + + kms "github.com/aws/aws-sdk-go-v2/service/kms" + gomock "github.com/golang/mock/gomock" +) + +// MockawsClient is a mock of awsClient interface. +type MockawsClient struct { + ctrl *gomock.Controller + recorder *MockawsClientMockRecorder +} + +// MockawsClientMockRecorder is the mock recorder for MockawsClient. +type MockawsClientMockRecorder struct { + mock *MockawsClient +} + +// NewMockawsClient creates a new mock instance. +func NewMockawsClient(ctrl *gomock.Controller) *MockawsClient { + mock := &MockawsClient{ctrl: ctrl} + mock.recorder = &MockawsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockawsClient) EXPECT() *MockawsClientMockRecorder { + return m.recorder +} + +// CreateAlias mocks base method. +func (m *MockawsClient) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateAlias", varargs...) + ret0, _ := ret[0].(*kms.CreateAliasOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateAlias indicates an expected call of CreateAlias. +func (mr *MockawsClientMockRecorder) CreateAlias(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAlias", reflect.TypeOf((*MockawsClient)(nil).CreateAlias), varargs...) +} + +// CreateKey mocks base method. +func (m *MockawsClient) CreateKey(ctx context.Context, params *kms.CreateKeyInput, optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateKey", varargs...) + ret0, _ := ret[0].(*kms.CreateKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateKey indicates an expected call of CreateKey. +func (mr *MockawsClientMockRecorder) CreateKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*MockawsClient)(nil).CreateKey), varargs...) +} + +// Decrypt mocks base method. +func (m *MockawsClient) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Decrypt", varargs...) + ret0, _ := ret[0].(*kms.DecryptOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockawsClientMockRecorder) Decrypt(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockawsClient)(nil).Decrypt), varargs...) +} + +// DescribeKey mocks base method. +func (m *MockawsClient) DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DescribeKey", varargs...) + ret0, _ := ret[0].(*kms.DescribeKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeKey indicates an expected call of DescribeKey. +func (mr *MockawsClientMockRecorder) DescribeKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeKey", reflect.TypeOf((*MockawsClient)(nil).DescribeKey), varargs...) +} + +// Encrypt mocks base method. +func (m *MockawsClient) Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Encrypt", varargs...) + ret0, _ := ret[0].(*kms.EncryptOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockawsClientMockRecorder) Encrypt(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockawsClient)(nil).Encrypt), varargs...) +} + +// GetPublicKey mocks base method. +func (m *MockawsClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetPublicKey", varargs...) + ret0, _ := ret[0].(*kms.GetPublicKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPublicKey indicates an expected call of GetPublicKey. +func (mr *MockawsClientMockRecorder) GetPublicKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKey", reflect.TypeOf((*MockawsClient)(nil).GetPublicKey), varargs...) +} + +// Sign mocks base method. +func (m *MockawsClient) Sign(ctx context.Context, params *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Sign", varargs...) + ret0, _ := ret[0].(*kms.SignOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign. +func (mr *MockawsClientMockRecorder) Sign(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*MockawsClient)(nil).Sign), varargs...) +} + +// Verify mocks base method. +func (m *MockawsClient) Verify(ctx context.Context, params *kms.VerifyInput, optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Verify", varargs...) + ret0, _ := ret[0].(*kms.VerifyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verify indicates an expected call of Verify. +func (mr *MockawsClientMockRecorder) Verify(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockawsClient)(nil).Verify), varargs...) +} + +// MockmetricsProvider is a mock of metricsProvider interface. +type MockmetricsProvider struct { + ctrl *gomock.Controller + recorder *MockmetricsProviderMockRecorder +} + +// MockmetricsProviderMockRecorder is the mock recorder for MockmetricsProvider. +type MockmetricsProviderMockRecorder struct { + mock *MockmetricsProvider +} + +// NewMockmetricsProvider creates a new mock instance. +func NewMockmetricsProvider(ctrl *gomock.Controller) *MockmetricsProvider { + mock := &MockmetricsProvider{ctrl: ctrl} + mock.recorder = &MockmetricsProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmetricsProvider) EXPECT() *MockmetricsProviderMockRecorder { + return m.recorder +} + +// DecryptCount mocks base method. +func (m *MockmetricsProvider) DecryptCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptCount") +} + +// DecryptCount indicates an expected call of DecryptCount. +func (mr *MockmetricsProviderMockRecorder) DecryptCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptCount", reflect.TypeOf((*MockmetricsProvider)(nil).DecryptCount)) +} + +// DecryptTime mocks base method. +func (m *MockmetricsProvider) DecryptTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptTime", value) +} + +// DecryptTime indicates an expected call of DecryptTime. +func (mr *MockmetricsProviderMockRecorder) DecryptTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptTime", reflect.TypeOf((*MockmetricsProvider)(nil).DecryptTime), value) +} + +// EncryptCount mocks base method. +func (m *MockmetricsProvider) EncryptCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptCount") +} + +// EncryptCount indicates an expected call of EncryptCount. +func (mr *MockmetricsProviderMockRecorder) EncryptCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptCount", reflect.TypeOf((*MockmetricsProvider)(nil).EncryptCount)) +} + +// EncryptTime mocks base method. +func (m *MockmetricsProvider) EncryptTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptTime", value) +} + +// EncryptTime indicates an expected call of EncryptTime. +func (mr *MockmetricsProviderMockRecorder) EncryptTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptTime", reflect.TypeOf((*MockmetricsProvider)(nil).EncryptTime), value) +} + +// ExportPublicKeyCount mocks base method. +func (m *MockmetricsProvider) ExportPublicKeyCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ExportPublicKeyCount") +} + +// ExportPublicKeyCount indicates an expected call of ExportPublicKeyCount. +func (mr *MockmetricsProviderMockRecorder) ExportPublicKeyCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportPublicKeyCount", reflect.TypeOf((*MockmetricsProvider)(nil).ExportPublicKeyCount)) +} + +// ExportPublicKeyTime mocks base method. +func (m *MockmetricsProvider) ExportPublicKeyTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ExportPublicKeyTime", value) +} + +// ExportPublicKeyTime indicates an expected call of ExportPublicKeyTime. +func (mr *MockmetricsProviderMockRecorder) ExportPublicKeyTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportPublicKeyTime", reflect.TypeOf((*MockmetricsProvider)(nil).ExportPublicKeyTime), value) +} + +// SignCount mocks base method. +func (m *MockmetricsProvider) SignCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SignCount") +} + +// SignCount indicates an expected call of SignCount. +func (mr *MockmetricsProviderMockRecorder) SignCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignCount", reflect.TypeOf((*MockmetricsProvider)(nil).SignCount)) +} + +// SignTime mocks base method. +func (m *MockmetricsProvider) SignTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SignTime", value) +} + +// SignTime indicates an expected call of SignTime. +func (mr *MockmetricsProviderMockRecorder) SignTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignTime", reflect.TypeOf((*MockmetricsProvider)(nil).SignTime), value) +} + +// VerifyCount mocks base method. +func (m *MockmetricsProvider) VerifyCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "VerifyCount") +} + +// VerifyCount indicates an expected call of VerifyCount. +func (mr *MockmetricsProviderMockRecorder) VerifyCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCount", reflect.TypeOf((*MockmetricsProvider)(nil).VerifyCount)) +} + +// VerifyTime mocks base method. +func (m *MockmetricsProvider) VerifyTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "VerifyTime", value) +} + +// VerifyTime indicates an expected call of VerifyTime. +func (mr *MockmetricsProviderMockRecorder) VerifyTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyTime", reflect.TypeOf((*MockmetricsProvider)(nil).VerifyTime), value) +} diff --git a/pkg/aws/service_test.go b/pkg/aws/service_test.go index 6f4250d..52f95dd 100644 --- a/pkg/aws/service_test.go +++ b/pkg/aws/service_test.go @@ -8,14 +8,16 @@ package aws //nolint:testpackage import ( "context" + "errors" "fmt" "testing" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/golang/mock/gomock" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,22 +27,25 @@ func TestSign(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) - svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - return &kms.SignOutput{ + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().Sign(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.SignOutput{ Signature: []byte("data"), - }, nil - }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{ + }, nil) + + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{ KeyMetadata: &types.KeyMetadata{ SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, KeySpec: types.KeySpecEccNistP256, }, - }, nil - }} + }, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) signature, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -49,19 +54,21 @@ func TestSign(t *testing.T) { }) t.Run("failed to sign", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - - svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - return nil, fmt.Errorf("failed to sign") - }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{ + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().Sign(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to sign")) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{ KeyMetadata: &types.KeyMetadata{ SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, }, - }, nil - }} + }, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -70,8 +77,11 @@ func TestSign(t *testing.T) { }) t.Run("failed to parse key id", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) + svc := New(awsConfig, metric, "", []Opts{}...) _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") @@ -84,28 +94,29 @@ func TestHealthCheck(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, - "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", - []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{}, nil) - svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{}, nil - }} + svc := New(awsConfig, metric, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + WithAWSClient(client), + ) err := svc.HealthCheck() require.NoError(t, err) }) t.Run("failed to list keys", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, - "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", - []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to list keys")) - svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return nil, fmt.Errorf("failed to list keys") - }} + svc := New(awsConfig, metric, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + WithAWSClient(client)) err := svc.HealthCheck() require.Error(t, err) @@ -119,14 +130,14 @@ func TestCreate(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - keyID := "key1" - svc.client = &mockAWSClient{createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }} + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) result, _, err := svc.Create(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -134,20 +145,19 @@ func TestCreate(t *testing.T) { }) t.Run("success: with key alias prefix", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", WithKeyAliasPrefix("dummyKeyAlias")) - keyID := "key1" - svc.client = &mockAWSClient{ - createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }, - createAliasFunc: func(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { - return &kms.CreateAliasOutput{}, nil - }, - } + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + client.EXPECT().CreateAlias(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateAliasOutput{}, nil) + + svc := New(awsConfig, metric, "", + WithKeyAliasPrefix("dummyKeyAlias"), + WithAWSClient(client), + ) result, _, err := svc.Create(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -155,7 +165,9 @@ func TestCreate(t *testing.T) { }) t.Run("key not supported", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", []Opts{}...) _, _, err := svc.Create(arieskms.ED25519) require.Error(t, err) @@ -167,9 +179,10 @@ func TestGet(t *testing.T) { awsConfig := aws.Config{ Region: "ca", } + metric := NewMockmetricsProvider(gomock.NewController(t)) t.Run("success", func(t *testing.T) { - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) + svc := New(&awsConfig, metric, "", []Opts{}...) keyID, err := svc.Get("key1") require.NoError(t, err) @@ -184,22 +197,18 @@ func TestCreateAndPubKeyBytes(t *testing.T) { t.Run("success", func(t *testing.T) { keyID := "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147" - - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) - - svc.client = &mockAWSClient{ - getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return &kms.GetPublicKeyOutput{ - PublicKey: []byte("publickey"), - SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, - }, nil - }, - createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }, - } + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.GetPublicKeyOutput{ + PublicKey: []byte("publickey"), + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, + }, nil) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + svc := New(&awsConfig, metric, "", WithAWSClient(client)) keyID, publicKey, err := svc.CreateAndExportPubKeyBytes(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -212,8 +221,9 @@ func TestSignMulti(t *testing.T) { awsConfig := aws.Config{ Region: "ca", } + metric := NewMockmetricsProvider(gomock.NewController(t)) - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) + svc := New(&awsConfig, metric, "", []Opts{}...) _, err := svc.SignMulti(nil, nil) require.Error(t, err) @@ -226,15 +236,17 @@ func TestPubKeyBytes(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) - svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return &kms.GetPublicKeyOutput{ + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.GetPublicKeyOutput{ PublicKey: []byte("publickey"), SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, - }, nil - }} + }, nil) + svc := New(awsConfig, metric, "", WithAWSClient(client)) keyID, keyType, err := svc.ExportPubKeyBytes( "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -244,12 +256,14 @@ func TestPubKeyBytes(t *testing.T) { }) t.Run("failed to export public key", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) - svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return nil, fmt.Errorf("failed to export public key") - }} + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to export public key")) + svc := New(awsConfig, metric, "", WithAWSClient(client)) _, _, err := svc.ExportPubKeyBytes( "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -258,7 +272,11 @@ func TestPubKeyBytes(t *testing.T) { }) t.Run("failed to parse key id", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) + + svc := New(awsConfig, metric, "", []Opts{}...) _, _, err := svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1") require.Error(t, err) @@ -266,91 +284,172 @@ func TestPubKeyBytes(t *testing.T) { }) } -type mockAWSClient struct { - signFunc func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) - getPublicKeyFunc func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) - verifyFunc func(ctx context.Context, params *kms.VerifyInput, - optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) - describeKeyFunc func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) - createKeyFunc func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) - createAliasFunc func(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) -} - -func (m *mockAWSClient) Sign(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - if m.signFunc != nil { - return m.signFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - if m.getPublicKeyFunc != nil { - return m.getPublicKeyFunc(ctx, params, optFns...) +func TestEncrypt(t *testing.T) { + awsConfig := &aws.Config{ + Region: "ca", } - return nil, nil //nolint:nilnil -} + t.Run("success", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().EncryptCount() + metric.EXPECT().EncryptTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) + msg := generateNonce(64) + encrypted := generateNonce(128) + + client.EXPECT().Encrypt(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + ctx context.Context, + params *kms.EncryptInput, + optFns ...func(*kms.Options), + ) (*kms.EncryptOutput, error) { + assert.Equal(t, "alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", *params.KeyId) + assert.Equal(t, msg, params.Plaintext) + assert.Equal(t, svc.encryptionAlgo, params.EncryptionAlgorithm) + + return &kms.EncryptOutput{ + CiphertextBlob: encrypted, + }, nil + }) -func (m *mockAWSClient) Verify(ctx context.Context, params *kms.VerifyInput, - optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) { - if m.verifyFunc != nil { - return m.verifyFunc(ctx, params, optFns...) - } + encryptedData, nonce, err := svc.Encrypt( + msg, + nil, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + ) - return nil, nil //nolint:nilnil -} + assert.NoError(t, err) + assert.Len(t, nonce, svc.nonceLength) + assert.Equal(t, encrypted, encryptedData) + }) -func (m *mockAWSClient) DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - if m.describeKeyFunc != nil { - return m.describeKeyFunc(ctx, params, optFns...) - } + t.Run("encryption err", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().EncryptCount() + metric.EXPECT().EncryptTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) + msg := generateNonce(64) + + client.EXPECT().Encrypt(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + ctx context.Context, + params *kms.EncryptInput, + optFns ...func(*kms.Options), + ) (*kms.EncryptOutput, error) { + return nil, errors.New("encryption err") + }) + + encryptedData, nonce, err := svc.Encrypt( + msg, + nil, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + ) + + assert.ErrorContains(t, err, "encryption err") + assert.Empty(t, nonce) + assert.Empty(t, encryptedData) + }) - return nil, nil //nolint:nilnil -} + t.Run("failed to parse key id", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().EncryptCount() + metric.EXPECT().EncryptTime(gomock.Any()) -func (m *mockAWSClient) CreateKey(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - if m.createKeyFunc != nil { - return m.createKeyFunc(ctx, params, optFns...) - } + svc := New(awsConfig, metric, "", []Opts{}...) - return nil, nil //nolint:nilnil + _, _, err := svc.Encrypt(nil, nil, "aws-kms://arn:aws:kms:key1") + require.Error(t, err) + require.Contains(t, err.Error(), "extracting key id from URI failed") + }) } -func (m *mockAWSClient) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { - if m.createAliasFunc != nil { - return m.createAliasFunc(ctx, params, optFns...) +func TestDecrypt(t *testing.T) { + awsConfig := &aws.Config{ + Region: "ca", } - return nil, nil //nolint:nilnil -} - -type mockMetrics struct{} + t.Run("success", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().DecryptCount() + metric.EXPECT().DecryptTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) + encrypted := generateNonce(64) + decrypted := generateNonce(128) + + client.EXPECT().Decrypt(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + ctx context.Context, + params *kms.DecryptInput, + optFns ...func(*kms.Options), + ) (*kms.DecryptOutput, error) { + assert.Equal(t, "alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", *params.KeyId) + assert.Equal(t, encrypted, params.CiphertextBlob) + assert.Equal(t, svc.encryptionAlgo, params.EncryptionAlgorithm) + + return &kms.DecryptOutput{ + Plaintext: decrypted, + }, nil + }) -func (m *mockMetrics) SignCount() { -} + decryptedData, err := svc.Decrypt( + nil, + encrypted, + nil, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + ) -func (m *mockMetrics) SignTime(value time.Duration) { -} + assert.NoError(t, err) + assert.Equal(t, decrypted, decryptedData) + }) -func (m *mockMetrics) ExportPublicKeyCount() { -} + t.Run("decryption err", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().DecryptCount() + metric.EXPECT().DecryptTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) + msg := generateNonce(64) + + client.EXPECT().Decrypt(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + ctx context.Context, + params *kms.DecryptInput, + optFns ...func(*kms.Options), + ) (*kms.DecryptOutput, error) { + return nil, errors.New("encryption err") + }) + + decrypted, err := svc.Decrypt( + msg, + nil, + nil, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + ) + + assert.ErrorContains(t, err, "encryption err") + assert.Empty(t, decrypted) + }) -func (m *mockMetrics) ExportPublicKeyTime(value time.Duration) { -} + t.Run("failed to parse key id", func(t *testing.T) { + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().DecryptCount() + metric.EXPECT().DecryptTime(gomock.Any()) -func (m *mockMetrics) VerifyCount() { -} + svc := New(awsConfig, metric, "", []Opts{}...) -func (m *mockMetrics) VerifyTime(value time.Duration) { + _, err := svc.Decrypt(nil, nil, nil, "aws-kms://arn:aws:kms:key1") + require.Error(t, err) + require.Contains(t, err.Error(), "extracting key id from URI failed") + }) }