Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/jgough/8503 keyvault keys #9

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions azkeys/keyvault.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
package azkeys

/**
* KeyVault implements the azure keyvault API:
*
* https://learn.microsoft.com/en-us/rest/api/keyvault/
*/

import (
"context"
"encoding/base64"
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/rkvst/go-rkvstcommon/logger"
)

// KeyVault is the azure keyvault client for interacting with keyvault keys
type KeyVault struct {
Name string
Authorizer autorest.Authorizer // optional, nil for production
}

// NewKeyVault creates a new keyvault client
func NewKeyVault(keyvaultURL string) *KeyVault {
kv := KeyVault{
Name: keyvaultURL,
}

return &kv
}

// GetKeyByKID gets the key by its KID
func (kv *KeyVault) GetKeyByKID(
ctx context.Context, kid string,
) (keyvault.KeyBundle, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("GetLatestKey: %s %s", kv.Name, kid)

kvClient, err := NewKvClient(kv.Authorizer)
if err != nil {
return keyvault.KeyBundle{}, err
}

keyName := GetKeyName(kid)
keyVersion := GetKeyVersion(kid)

key, err := kvClient.GetKey(ctx, kv.Name, keyName, keyVersion)
if err != nil {
return keyvault.KeyBundle{}, fmt.Errorf("failed to read key: %w", err)
}

return key, nil

}

// GetLatestKey returns the latest version of the identified key
func (kv *KeyVault) GetLatestKey(
ctx context.Context, keyID string,
) (keyvault.KeyBundle, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("GetLatestKey: %s %s", kv.Name, keyID)

kvClient, err := NewKvClient(kv.Authorizer)
if err != nil {
return keyvault.KeyBundle{}, err
}

key, err := kvClient.GetKey(ctx, kv.Name, keyID, "")
if err != nil {
return keyvault.KeyBundle{}, fmt.Errorf("failed to read key: %w", err)
}

return key, nil
}

// GetKeyVersionsKeys returns all the keys, for all the versions of the identified key.
//
// The keys returned are the public half of the asymetric keys
func (kv *KeyVault) GetKeyVersionsKeys(
ctx context.Context, keyID string,
) ([]keyvault.KeyBundle, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("GetKeyVersions: %s %s", kv.Name, keyID)

kvClient, err := NewKvClient(kv.Authorizer)
if err != nil {
return []keyvault.KeyBundle{}, err
}

pageLimit := int32(1)
keyVersions, err := kvClient.GetKeyVersions(ctx, kv.Name, keyID, &pageLimit)
if err != nil {
return []keyvault.KeyBundle{}, fmt.Errorf("failed to read key: %w", err)
}

keyVersionValues := keyVersions.Values()

keys, err := kv.getKeysFromVersions(ctx, keyVersionValues)
if err != nil {
log.Infof("failed to get key versions keys: %v", err)
return []keyvault.KeyBundle{}, err
}

for keyVersions.NotDone() {
err := keyVersions.NextWithContext(ctx)
if err != nil {
log.Infof("failed to get key versions: %v", err)
return []keyvault.KeyBundle{}, err
}

keyVersionValues = keyVersions.Values()

nextKeys, err := kv.getKeysFromVersions(ctx, keyVersionValues)
if err != nil {
log.Infof("failed to get next key versions keys: %v", err)
return []keyvault.KeyBundle{}, err
}

keys = append(keys, nextKeys...)
}

return keys, nil
}

// getKeysFromVersions gets the keys from the given key versions
func (kv *KeyVault) getKeysFromVersions(ctx context.Context, keyVersions []keyvault.KeyItem) ([]keyvault.KeyBundle, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("getKeysFromVersions")

keys := []keyvault.KeyBundle{}

for _, keyVersionValue := range keyVersions {

// if we don't have a kid we can't find the key
if keyVersionValue.Kid == nil {
continue
}

key, err := kv.GetKeyByKID(ctx, *keyVersionValue.Kid)
if err != nil {
return []keyvault.KeyBundle{}, fmt.Errorf("failed get key version: %w", err)
}

keys = append(keys, key)
}

return keys, nil
}

// GetKeyVersion gets the version of the given key
func GetKeyVersion(kid string) string {

// the kid is comprised of the {name}/{version}
kidParts := strings.Split(kid, "/")

// get the version part
return kidParts[len(kidParts)-1]

}

// GetKeyName gets the name of the given key
func GetKeyName(kid string) string {

// the kid is comprised of the {name}/{version}
kidParts := strings.Split(kid, "/")

// get the name part
return kidParts[len(kidParts)-2]
}

// Sign signs a given payload
func (kv *KeyVault) Sign(
ctx context.Context,
payload []byte,
keyID string,
keyVersion string,
algorithm keyvault.JSONWebKeySignatureAlgorithm,
) ([]byte, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("Sign: %s %s", kv.Name, keyID)

kvClient, err := NewKvClient(kv.Authorizer)
if err != nil {
return []byte{}, err
}

payloadStr := base64.URLEncoding.EncodeToString(payload)

logger.Sugar.Infof("Payload Str: %v", payloadStr)

params := keyvault.KeySignParameters{
Algorithm: algorithm,
Value: &payloadStr,
}

signatureb64, err := kvClient.Sign(ctx, kv.Name, keyID, keyVersion, params)
if err != nil {
return []byte{}, fmt.Errorf("failed toado sign payl: %w", err)
}

logger.Sugar.Infof("SignatureB64: %v", *signatureb64.Result)
signature, err := base64.URLEncoding.DecodeString(*signatureb64.Result)
return signature, err

}

// Verify verifies a given payload
func (kv *KeyVault) Verify(
ctx context.Context,
signature []byte,
digest []byte,
keyID string,
keyVersion string,
algorithm keyvault.JSONWebKeySignatureAlgorithm,
) (bool, error) {

log := logger.Sugar.FromContext(ctx)
defer log.Close()

log.Infof("Verify: %s %s", kv.Name, keyID)

kvClient, err := NewKvClient(kv.Authorizer)
if err != nil {
return false, err
}

signatureStr := base64.URLEncoding.EncodeToString(signature)
digestStr := base64.URLEncoding.EncodeToString(digest)

params := keyvault.KeyVerifyParameters{
Algorithm: algorithm,
Signature: &signatureStr,
Digest: &digestStr,
}

result, err := kvClient.Verify(ctx, kv.Name, keyID, keyVersion, params)
if err != nil {
return false, fmt.Errorf("failed to verify payload: %w", err)
}
return *result.Value, err

}
67 changes: 67 additions & 0 deletions azkeys/keyvault_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package azkeys

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestGetKeyVersion tests:
//
// 1. with a valid keyvault KID we get the key version back successfully
func TestGetKeyVersion(t *testing.T) {
type args struct {
kid string
}
tests := []struct {
name string
args args
expected string
}{
{
name: "positive",
args: args{
kid: "https://example.vault.azure.net/keys/my-key/6eee6743b34e4291807565af6b756bac",
},
expected: "6eee6743b34e4291807565af6b756bac",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

actual := GetKeyVersion(test.args.kid)

assert.Equal(t, test.expected, actual)
})
}
}

// TestGetKeyName tests:
//
// 1. with a valid keyvault KID we get the key name back successfully
func TestGetKeyName(t *testing.T) {
type args struct {
kid string
}
tests := []struct {
name string
args args
expected string
}{
{
name: "positive",
args: args{
kid: "https://example.vault.azure.net/keys/my-key/6eee6743b34e4291807565af6b756bac",
},
expected: "my-key",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

actual := GetKeyName(test.args.kid)

assert.Equal(t, test.expected, actual)
})
}
}
Loading