Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
feat: add encrypt, decrypt for aws
Browse files Browse the repository at this point in the history
  • Loading branch information
skynet2 committed Mar 22, 2023
1 parent 8f2dc35 commit c6f21d3
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 3 deletions.
104 changes: 101 additions & 3 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ Copyright SecureKey Technologies Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

//go:generate mockgen -destination service_mocks.go -package aws -source=service.go

package aws

import (
"context"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/asn1"
Expand Down Expand Up @@ -37,11 +40,17 @@ type awsClient interface { //nolint:dupl
optFns ...func(*kms.Options)) (*kms.CreateKeyOutput, error)
CreateAlias(ctx context.Context, params *kms.CreateAliasInput,
optFns ...func(*kms.Options)) (*kms.CreateAliasOutput, error)
Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error)
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
}

type metricsProvider interface {
SignCount()
EncryptCount()
DecryptCount()
SignTime(value time.Duration)
EncryptTime(value time.Duration)
DecryptTime(value time.Duration)
ExportPublicKeyCount()
ExportPublicKeyTime(value time.Duration)
VerifyCount()
Expand All @@ -58,6 +67,8 @@ type Service struct {
client awsClient
metrics metricsProvider
healthCheckKeyID string
encryptionAlgo types.EncryptionAlgorithmSpec
nonceLength int
}

const (
Expand All @@ -80,20 +91,100 @@ var keySpecToCurve = map[types.KeySpec]elliptic.Curve{
}

// New return aws service.
func New(awsConfig *aws.Config, metrics metricsProvider,
healthCheckKeyID string, opts ...Opts) *Service {
func New(
awsConfig *aws.Config,
metrics metricsProvider,
healthCheckKeyID string,
opts ...Opts,
) *Service {
options := newOpts()

for _, opt := range opts {
opt(options)
}
client := options.awsClient
if client == nil {
client = kms.NewFromConfig(*awsConfig)
}

return &Service{
options: options,
client: kms.NewFromConfig(*awsConfig),
client: client,
metrics: metrics,
healthCheckKeyID: healthCheckKeyID,
encryptionAlgo: types.EncryptionAlgorithmSpecRsaesOaepSha256,
nonceLength: 16,
}
}

// Decrypt data.
func (s *Service) Decrypt(_, aad, _ []byte, kh interface{}) ([]byte, error) {
startTime := time.Now()

defer func() {
if s.metrics != nil {
s.metrics.DecryptTime(time.Since(startTime))
}
}()

if s.metrics != nil {
s.metrics.DecryptCount()
}

keyID, err := s.getKeyID(kh.(string))
if err != nil {
return nil, err
}

input := &kms.DecryptInput{
CiphertextBlob: aad,
EncryptionAlgorithm: s.encryptionAlgo,
KeyId: aws.String(keyID),
}

resp, err := s.client.Decrypt(context.Background(), input)
if err != nil {
return nil, err
}

return resp.Plaintext, nil
}

// Encrypt data.
func (s *Service) Encrypt(
msg []byte,
_ []byte,
kh interface{},
) ([]byte, []byte, error) {
startTime := time.Now()

defer func() {
if s.metrics != nil {
s.metrics.EncryptTime(time.Since(startTime))
}
}()

if s.metrics != nil {
s.metrics.EncryptCount()
}

keyID, err := s.getKeyID(kh.(string))
if err != nil {
return nil, nil, err
}

input := &kms.EncryptInput{
KeyId: aws.String(keyID),
Plaintext: msg,
EncryptionAlgorithm: s.encryptionAlgo,
}

resp, err := s.client.Encrypt(context.Background(), input)
if err != nil {
return nil, nil, err
}

return resp.CiphertextBlob, generateNonce(s.nonceLength), nil
}

// Sign data.
Expand Down Expand Up @@ -310,6 +401,13 @@ func (s *Service) getKeyID(keyURI string) (string, error) {
return r[4], nil
}

func generateNonce(length int) []byte {
key := make([]byte, length)
_, _ = rand.Read(key)

return key
}

func hashMessage(message []byte, algorithm types.SigningAlgorithmSpec) ([]byte, error) {
var digest hash.Hash

Expand Down
173 changes: 173 additions & 0 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ SPDX-License-Identifier: Apache-2.0
package aws //nolint:testpackage

import (
"context"
"errors"
"fmt"
"testing"

Expand All @@ -15,6 +17,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/golang/mock/gomock"
arieskms "github.com/hyperledger/aries-framework-go/pkg/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -280,3 +283,173 @@ func TestPubKeyBytes(t *testing.T) {
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

func TestEncrypt(t *testing.T) {
awsConfig := &aws.Config{
Region: "ca",
}

t.Run("success", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().EncryptCount()
metric.EXPECT().EncryptTime(gomock.Any())

client := NewMockawsClient(gomock.NewController(t))

svc := New(awsConfig, metric, "", WithAWSClient(client))
msg := generateNonce(64)
encrypted := generateNonce(128)

client.EXPECT().Encrypt(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(
ctx context.Context,
params *kms.EncryptInput,
optFns ...func(*kms.Options),
) (*kms.EncryptOutput, error) {
assert.Equal(t, "alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", *params.KeyId)
assert.Equal(t, msg, params.Plaintext)
assert.Equal(t, svc.encryptionAlgo, params.EncryptionAlgorithm)

return &kms.EncryptOutput{
CiphertextBlob: encrypted,
}, nil
})

encryptedData, nonce, err := svc.Encrypt(
msg,
nil,
"aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147",
)

assert.NoError(t, err)
assert.Len(t, nonce, svc.nonceLength)
assert.Equal(t, encrypted, encryptedData)
})

t.Run("encryption err", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().EncryptCount()
metric.EXPECT().EncryptTime(gomock.Any())

client := NewMockawsClient(gomock.NewController(t))

svc := New(awsConfig, metric, "", WithAWSClient(client))
msg := generateNonce(64)

client.EXPECT().Encrypt(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(
ctx context.Context,
params *kms.EncryptInput,
optFns ...func(*kms.Options),
) (*kms.EncryptOutput, error) {
return nil, errors.New("encryption err")
})

encryptedData, nonce, err := svc.Encrypt(
msg,
nil,
"aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147",
)

assert.ErrorContains(t, err, "encryption err")
assert.Empty(t, nonce)
assert.Empty(t, encryptedData)
})

t.Run("failed to parse key id", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().EncryptCount()
metric.EXPECT().EncryptTime(gomock.Any())

svc := New(awsConfig, metric, "", []Opts{}...)

_, _, err := svc.Encrypt(nil, nil, "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

func TestDecrypt(t *testing.T) {
awsConfig := &aws.Config{
Region: "ca",
}

t.Run("success", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().DecryptCount()
metric.EXPECT().DecryptTime(gomock.Any())

client := NewMockawsClient(gomock.NewController(t))

svc := New(awsConfig, metric, "", WithAWSClient(client))
encrypted := generateNonce(64)
decrypted := generateNonce(128)

client.EXPECT().Decrypt(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(
ctx context.Context,
params *kms.DecryptInput,
optFns ...func(*kms.Options),
) (*kms.DecryptOutput, error) {
assert.Equal(t, "alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147", *params.KeyId)
assert.Equal(t, encrypted, params.CiphertextBlob)
assert.Equal(t, svc.encryptionAlgo, params.EncryptionAlgorithm)

return &kms.DecryptOutput{
Plaintext: decrypted,
}, nil
})

decryptedData, err := svc.Decrypt(
nil,
encrypted,
nil,
"aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147",
)

assert.NoError(t, err)
assert.Equal(t, decrypted, decryptedData)
})

t.Run("decryption err", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().DecryptCount()
metric.EXPECT().DecryptTime(gomock.Any())

client := NewMockawsClient(gomock.NewController(t))

svc := New(awsConfig, metric, "", WithAWSClient(client))
msg := generateNonce(64)

client.EXPECT().Decrypt(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(
ctx context.Context,
params *kms.DecryptInput,
optFns ...func(*kms.Options),
) (*kms.DecryptOutput, error) {
return nil, errors.New("encryption err")
})

decrypted, err := svc.Decrypt(
msg,
nil,
nil,
"aws-kms://arn:aws:kms:ca-central-1:111122223333:alias/800d5768-3fd7-4edd-a4b8-4c81c3e4c147",
)

assert.ErrorContains(t, err, "encryption err")
assert.Empty(t, decrypted)
})

t.Run("failed to parse key id", func(t *testing.T) {
metric := NewMockmetricsProvider(gomock.NewController(t))
metric.EXPECT().DecryptCount()
metric.EXPECT().DecryptTime(gomock.Any())

svc := New(awsConfig, metric, "", []Opts{}...)

_, err := svc.Decrypt(nil, nil, nil, "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

0 comments on commit c6f21d3

Please sign in to comment.