From cbcc576836d771636c95ac6e1a4fc0690accb24f Mon Sep 17 00:00:00 2001 From: Stas D Date: Wed, 22 Mar 2023 10:48:58 +0100 Subject: [PATCH] refactor: use gomock Signed-off-by: Stas D --- go.mod | 2 +- pkg/aws/opts.go | 6 + pkg/aws/service_mocks.go | 340 +++++++++++++++++++++++++++++++++++++++ pkg/aws/service_test.go | 308 +++++++++++++++-------------------- 4 files changed, 480 insertions(+), 176 deletions(-) create mode 100644 pkg/aws/service_mocks.go diff --git a/go.mod b/go.mod index 82b1b91..2e45115 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ go 1.18 require ( github.com/aws/aws-sdk-go v1.43.9 + github.com/aws/aws-sdk-go-v2 v1.17.3 github.com/aws/aws-sdk-go-v2/service/kms v1.20.0 github.com/btcsuite/btcd v0.22.1 github.com/golang/mock v1.6.0 @@ -29,7 +30,6 @@ require ( require ( github.com/VictoriaMetrics/fastcache v1.5.7 // indirect - github.com/aws/aws-sdk-go-v2 v1.17.3 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect github.com/aws/smithy-go v1.13.5 // indirect diff --git a/pkg/aws/opts.go b/pkg/aws/opts.go index 3d46b6d..ebf474d 100644 --- a/pkg/aws/opts.go +++ b/pkg/aws/opts.go @@ -12,6 +12,7 @@ import ( type opts struct { keyAliasPrefix string + awsClient awsClient } // NewOpts create new opts populated with environment variable. @@ -34,3 +35,8 @@ type Opts func(opts *opts) func WithKeyAliasPrefix(prefix string) Opts { return func(opts *opts) { opts.keyAliasPrefix = prefix } } + +// WithAWSClient sets custom aws client +func WithAWSClient(client awsClient) Opts { + return func(opts *opts) { opts.awsClient = client } +} diff --git a/pkg/aws/service_mocks.go b/pkg/aws/service_mocks.go new file mode 100644 index 0000000..ac88063 --- /dev/null +++ b/pkg/aws/service_mocks.go @@ -0,0 +1,340 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: service.go + +// Package aws is a generated GoMock package. +package aws + +import ( + context "context" + reflect "reflect" + time "time" + + kms "github.com/aws/aws-sdk-go-v2/service/kms" + gomock "github.com/golang/mock/gomock" +) + +// MockawsClient is a mock of awsClient interface. +type MockawsClient struct { + ctrl *gomock.Controller + recorder *MockawsClientMockRecorder +} + +// MockawsClientMockRecorder is the mock recorder for MockawsClient. +type MockawsClientMockRecorder struct { + mock *MockawsClient +} + +// NewMockawsClient creates a new mock instance. +func NewMockawsClient(ctrl *gomock.Controller) *MockawsClient { + mock := &MockawsClient{ctrl: ctrl} + mock.recorder = &MockawsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockawsClient) EXPECT() *MockawsClientMockRecorder { + return m.recorder +} + +// CreateAlias mocks base method. +func (m *MockawsClient) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateAlias", varargs...) + ret0, _ := ret[0].(*kms.CreateAliasOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateAlias indicates an expected call of CreateAlias. +func (mr *MockawsClientMockRecorder) CreateAlias(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAlias", reflect.TypeOf((*MockawsClient)(nil).CreateAlias), varargs...) +} + +// CreateKey mocks base method. +func (m *MockawsClient) CreateKey(ctx context.Context, params *kms.CreateKeyInput, optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateKey", varargs...) + ret0, _ := ret[0].(*kms.CreateKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateKey indicates an expected call of CreateKey. +func (mr *MockawsClientMockRecorder) CreateKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateKey", reflect.TypeOf((*MockawsClient)(nil).CreateKey), varargs...) +} + +// Decrypt mocks base method. +func (m *MockawsClient) Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Decrypt", varargs...) + ret0, _ := ret[0].(*kms.DecryptOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockawsClientMockRecorder) Decrypt(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockawsClient)(nil).Decrypt), varargs...) +} + +// DescribeKey mocks base method. +func (m *MockawsClient) DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DescribeKey", varargs...) + ret0, _ := ret[0].(*kms.DescribeKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeKey indicates an expected call of DescribeKey. +func (mr *MockawsClientMockRecorder) DescribeKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeKey", reflect.TypeOf((*MockawsClient)(nil).DescribeKey), varargs...) +} + +// Encrypt mocks base method. +func (m *MockawsClient) Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Encrypt", varargs...) + ret0, _ := ret[0].(*kms.EncryptOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockawsClientMockRecorder) Encrypt(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockawsClient)(nil).Encrypt), varargs...) +} + +// GetPublicKey mocks base method. +func (m *MockawsClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetPublicKey", varargs...) + ret0, _ := ret[0].(*kms.GetPublicKeyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPublicKey indicates an expected call of GetPublicKey. +func (mr *MockawsClientMockRecorder) GetPublicKey(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKey", reflect.TypeOf((*MockawsClient)(nil).GetPublicKey), varargs...) +} + +// Sign mocks base method. +func (m *MockawsClient) Sign(ctx context.Context, params *kms.SignInput, optFns ...func(*kms.Options)) (*kms.SignOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Sign", varargs...) + ret0, _ := ret[0].(*kms.SignOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign. +func (mr *MockawsClientMockRecorder) Sign(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*MockawsClient)(nil).Sign), varargs...) +} + +// Verify mocks base method. +func (m *MockawsClient) Verify(ctx context.Context, params *kms.VerifyInput, optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Verify", varargs...) + ret0, _ := ret[0].(*kms.VerifyOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verify indicates an expected call of Verify. +func (mr *MockawsClientMockRecorder) Verify(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockawsClient)(nil).Verify), varargs...) +} + +// MockmetricsProvider is a mock of metricsProvider interface. +type MockmetricsProvider struct { + ctrl *gomock.Controller + recorder *MockmetricsProviderMockRecorder +} + +// MockmetricsProviderMockRecorder is the mock recorder for MockmetricsProvider. +type MockmetricsProviderMockRecorder struct { + mock *MockmetricsProvider +} + +// NewMockmetricsProvider creates a new mock instance. +func NewMockmetricsProvider(ctrl *gomock.Controller) *MockmetricsProvider { + mock := &MockmetricsProvider{ctrl: ctrl} + mock.recorder = &MockmetricsProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmetricsProvider) EXPECT() *MockmetricsProviderMockRecorder { + return m.recorder +} + +// DecryptCount mocks base method. +func (m *MockmetricsProvider) DecryptCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptCount") +} + +// DecryptCount indicates an expected call of DecryptCount. +func (mr *MockmetricsProviderMockRecorder) DecryptCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptCount", reflect.TypeOf((*MockmetricsProvider)(nil).DecryptCount)) +} + +// DecryptTime mocks base method. +func (m *MockmetricsProvider) DecryptTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DecryptTime", value) +} + +// DecryptTime indicates an expected call of DecryptTime. +func (mr *MockmetricsProviderMockRecorder) DecryptTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptTime", reflect.TypeOf((*MockmetricsProvider)(nil).DecryptTime), value) +} + +// EncryptCount mocks base method. +func (m *MockmetricsProvider) EncryptCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptCount") +} + +// EncryptCount indicates an expected call of EncryptCount. +func (mr *MockmetricsProviderMockRecorder) EncryptCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptCount", reflect.TypeOf((*MockmetricsProvider)(nil).EncryptCount)) +} + +// EncryptTime mocks base method. +func (m *MockmetricsProvider) EncryptTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptTime", value) +} + +// EncryptTime indicates an expected call of EncryptTime. +func (mr *MockmetricsProviderMockRecorder) EncryptTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptTime", reflect.TypeOf((*MockmetricsProvider)(nil).EncryptTime), value) +} + +// ExportPublicKeyCount mocks base method. +func (m *MockmetricsProvider) ExportPublicKeyCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ExportPublicKeyCount") +} + +// ExportPublicKeyCount indicates an expected call of ExportPublicKeyCount. +func (mr *MockmetricsProviderMockRecorder) ExportPublicKeyCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportPublicKeyCount", reflect.TypeOf((*MockmetricsProvider)(nil).ExportPublicKeyCount)) +} + +// ExportPublicKeyTime mocks base method. +func (m *MockmetricsProvider) ExportPublicKeyTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ExportPublicKeyTime", value) +} + +// ExportPublicKeyTime indicates an expected call of ExportPublicKeyTime. +func (mr *MockmetricsProviderMockRecorder) ExportPublicKeyTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportPublicKeyTime", reflect.TypeOf((*MockmetricsProvider)(nil).ExportPublicKeyTime), value) +} + +// SignCount mocks base method. +func (m *MockmetricsProvider) SignCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SignCount") +} + +// SignCount indicates an expected call of SignCount. +func (mr *MockmetricsProviderMockRecorder) SignCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignCount", reflect.TypeOf((*MockmetricsProvider)(nil).SignCount)) +} + +// SignTime mocks base method. +func (m *MockmetricsProvider) SignTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SignTime", value) +} + +// SignTime indicates an expected call of SignTime. +func (mr *MockmetricsProviderMockRecorder) SignTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignTime", reflect.TypeOf((*MockmetricsProvider)(nil).SignTime), value) +} + +// VerifyCount mocks base method. +func (m *MockmetricsProvider) VerifyCount() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "VerifyCount") +} + +// VerifyCount indicates an expected call of VerifyCount. +func (mr *MockmetricsProviderMockRecorder) VerifyCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCount", reflect.TypeOf((*MockmetricsProvider)(nil).VerifyCount)) +} + +// VerifyTime mocks base method. +func (m *MockmetricsProvider) VerifyTime(value time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "VerifyTime", value) +} + +// VerifyTime indicates an expected call of VerifyTime. +func (mr *MockmetricsProviderMockRecorder) VerifyTime(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyTime", reflect.TypeOf((*MockmetricsProvider)(nil).VerifyTime), value) +} diff --git a/pkg/aws/service_test.go b/pkg/aws/service_test.go index 6f4250d..21b0c01 100644 --- a/pkg/aws/service_test.go +++ b/pkg/aws/service_test.go @@ -7,14 +7,13 @@ SPDX-License-Identifier: Apache-2.0 package aws //nolint:testpackage import ( - "context" "fmt" "testing" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/golang/mock/gomock" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" "github.com/stretchr/testify/require" ) @@ -25,22 +24,25 @@ func TestSign(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) - svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - return &kms.SignOutput{ + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().Sign(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.SignOutput{ Signature: []byte("data"), - }, nil - }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{ + }, nil) + + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{ KeyMetadata: &types.KeyMetadata{ SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, KeySpec: types.KeySpecEccNistP256, }, - }, nil - }} + }, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) signature, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -49,19 +51,21 @@ func TestSign(t *testing.T) { }) t.Run("failed to sign", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - - svc.client = &mockAWSClient{signFunc: func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - return nil, fmt.Errorf("failed to sign") - }, describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{ + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) + + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().Sign(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to sign")) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{ KeyMetadata: &types.KeyMetadata{ SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, }, - }, nil - }} + }, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -70,8 +74,11 @@ func TestSign(t *testing.T) { }) t.Run("failed to parse key id", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().SignCount() + metric.EXPECT().SignTime(gomock.Any()) + svc := New(awsConfig, metric, "", []Opts{}...) _, err := svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") @@ -84,28 +91,29 @@ func TestHealthCheck(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, - "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", - []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.DescribeKeyOutput{}, nil) - svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return &kms.DescribeKeyOutput{}, nil - }} + svc := New(awsConfig, metric, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + WithAWSClient(client), + ) err := svc.HealthCheck() require.NoError(t, err) }) t.Run("failed to list keys", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, - "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", - []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().DescribeKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to list keys")) - svc.client = &mockAWSClient{describeKeyFunc: func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - return nil, fmt.Errorf("failed to list keys") - }} + svc := New(awsConfig, metric, + "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", + WithAWSClient(client)) err := svc.HealthCheck() require.Error(t, err) @@ -119,14 +127,19 @@ func TestCreate(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) - keyID := "key1" - svc.client = &mockAWSClient{createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }} + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + + svc := New(awsConfig, metric, "", WithAWSClient(client)) + + //svc.client = &mockAWSClient{createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, + // optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + // return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil + //}} result, _, err := svc.Create(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -134,20 +147,19 @@ func TestCreate(t *testing.T) { }) t.Run("success: with key alias prefix", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", WithKeyAliasPrefix("dummyKeyAlias")) - keyID := "key1" - svc.client = &mockAWSClient{ - createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }, - createAliasFunc: func(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { - return &kms.CreateAliasOutput{}, nil - }, - } + metric := NewMockmetricsProvider(gomock.NewController(t)) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + client.EXPECT().CreateAlias(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateAliasOutput{}, nil) + + svc := New(awsConfig, metric, "", + WithKeyAliasPrefix("dummyKeyAlias"), + WithAWSClient(client), + ) result, _, err := svc.Create(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -155,7 +167,9 @@ func TestCreate(t *testing.T) { }) t.Run("key not supported", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + + svc := New(awsConfig, metric, "", []Opts{}...) _, _, err := svc.Create(arieskms.ED25519) require.Error(t, err) @@ -167,9 +181,10 @@ func TestGet(t *testing.T) { awsConfig := aws.Config{ Region: "ca", } + metric := NewMockmetricsProvider(gomock.NewController(t)) t.Run("success", func(t *testing.T) { - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) + svc := New(&awsConfig, metric, "", []Opts{}...) keyID, err := svc.Get("key1") require.NoError(t, err) @@ -184,22 +199,32 @@ func TestCreateAndPubKeyBytes(t *testing.T) { t.Run("success", func(t *testing.T) { keyID := "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147" - - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) - - svc.client = &mockAWSClient{ - getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return &kms.GetPublicKeyOutput{ - PublicKey: []byte("publickey"), - SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, - }, nil - }, - createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil - }, - } + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.GetPublicKeyOutput{ + PublicKey: []byte("publickey"), + SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, + }, nil) + client.EXPECT().CreateKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil) + svc := New(&awsConfig, metric, "", WithAWSClient(client)) + + //svc.client = &mockAWSClient{ + // getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + // optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + // return &kms.GetPublicKeyOutput{ + // PublicKey: []byte("publickey"), + // SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, + // }, nil + // }, + // createKeyFunc: func(ctx context.Context, params *kms.CreateKeyInput, + // optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + // return &kms.CreateKeyOutput{KeyMetadata: &types.KeyMetadata{KeyId: &keyID}}, nil + // }, + //} keyID, publicKey, err := svc.CreateAndExportPubKeyBytes(arieskms.ECDSAP256DER) require.NoError(t, err) @@ -212,8 +237,9 @@ func TestSignMulti(t *testing.T) { awsConfig := aws.Config{ Region: "ca", } + metric := NewMockmetricsProvider(gomock.NewController(t)) - svc := New(&awsConfig, &mockMetrics{}, "", []Opts{}...) + svc := New(&awsConfig, metric, "", []Opts{}...) _, err := svc.SignMulti(nil, nil) require.Error(t, err) @@ -226,15 +252,25 @@ func TestPubKeyBytes(t *testing.T) { } t.Run("success", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) - svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return &kms.GetPublicKeyOutput{ + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&kms.GetPublicKeyOutput{ PublicKey: []byte("publickey"), SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, - }, nil - }} + }, nil) + svc := New(awsConfig, metric, "", WithAWSClient(client)) + + //svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + // optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + // return &kms.GetPublicKeyOutput{ + // PublicKey: []byte("publickey"), + // SigningAlgorithms: []types.SigningAlgorithmSpec{types.SigningAlgorithmSpecEcdsaSha256}, + // }, nil + //}} keyID, keyType, err := svc.ExportPubKeyBytes( "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -244,12 +280,19 @@ func TestPubKeyBytes(t *testing.T) { }) t.Run("failed to export public key", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) - svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - return nil, fmt.Errorf("failed to export public key") - }} + client := NewMockawsClient(gomock.NewController(t)) + client.EXPECT().GetPublicKey(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to export public key")) + svc := New(awsConfig, metric, "", WithAWSClient(client)) + + //svc.client = &mockAWSClient{getPublicKeyFunc: func(ctx context.Context, params *kms.GetPublicKeyInput, + // optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + // return nil, fmt.Errorf("failed to export public key") + //}} _, _, err := svc.ExportPubKeyBytes( "aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147") @@ -258,99 +301,14 @@ func TestPubKeyBytes(t *testing.T) { }) t.Run("failed to parse key id", func(t *testing.T) { - svc := New(awsConfig, &mockMetrics{}, "", []Opts{}...) + metric := NewMockmetricsProvider(gomock.NewController(t)) + metric.EXPECT().ExportPublicKeyCount() + metric.EXPECT().ExportPublicKeyTime(gomock.Any()) + + svc := New(awsConfig, metric, "", []Opts{}...) _, _, err := svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) } - -type mockAWSClient struct { - signFunc func(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) - getPublicKeyFunc func(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) - verifyFunc func(ctx context.Context, params *kms.VerifyInput, - optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) - describeKeyFunc func(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) - createKeyFunc func(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) - createAliasFunc func(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) -} - -func (m *mockAWSClient) Sign(ctx context.Context, params *kms.SignInput, - optFns ...func(*kms.Options)) (*kms.SignOutput, error) { - if m.signFunc != nil { - return m.signFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) GetPublicKey(ctx context.Context, params *kms.GetPublicKeyInput, - optFns ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { - if m.getPublicKeyFunc != nil { - return m.getPublicKeyFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) Verify(ctx context.Context, params *kms.VerifyInput, - optFns ...func(*kms.Options)) (*kms.VerifyOutput, error) { - if m.verifyFunc != nil { - return m.verifyFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) DescribeKey(ctx context.Context, params *kms.DescribeKeyInput, - optFns ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { - if m.describeKeyFunc != nil { - return m.describeKeyFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) CreateKey(ctx context.Context, params *kms.CreateKeyInput, - optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - if m.createKeyFunc != nil { - return m.createKeyFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -func (m *mockAWSClient) CreateAlias(ctx context.Context, params *kms.CreateAliasInput, - optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { - if m.createAliasFunc != nil { - return m.createAliasFunc(ctx, params, optFns...) - } - - return nil, nil //nolint:nilnil -} - -type mockMetrics struct{} - -func (m *mockMetrics) SignCount() { -} - -func (m *mockMetrics) SignTime(value time.Duration) { -} - -func (m *mockMetrics) ExportPublicKeyCount() { -} - -func (m *mockMetrics) ExportPublicKeyTime(value time.Duration) { -} - -func (m *mockMetrics) VerifyCount() { -} - -func (m *mockMetrics) VerifyTime(value time.Duration) { -}