Skip to content

Commit

Permalink
feature(sdk): Support custom key splits
Browse files Browse the repository at this point in the history
Implementation of opentdf/spec#32

This is a proposal to allow customizing how a client shares key data across multiple KASes. With a split, you can copy the same share to multiple providers, allowing for robustness if a given KAS is unavailable - or if a decrypting user or application does not have authorization with that KAS.
  • Loading branch information
dmihalcik-virtru committed Jun 28, 2024
1 parent 3db97fc commit 645cf23
Show file tree
Hide file tree
Showing 8 changed files with 683 additions and 299 deletions.
75 changes: 60 additions & 15 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newKASClient(dialOptions []grpc.DialOption, accessTokenSource auth.AccessTo
}

// there is no connection caching as of now
func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.RewrapResponse, error) {
func (k *KASClient) makeRewrapRequest(ctx context.Context, keyAccess KeyAccess, policy string) (*kas.RewrapResponse, error) {
rewrapRequest, err := k.getRewrapRequest(keyAccess, policy)
if err != nil {
return nil, err
Expand All @@ -83,7 +83,6 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
}
defer conn.Close()

ctx := context.Background()
serviceClient := kas.NewAccessServiceClient(conn)

response, err := serviceClient.Rewrap(ctx, rewrapRequest)
Expand All @@ -94,8 +93,8 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
return response, nil
}

func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
response, err := k.makeRewrapRequest(keyAccess, policy)
func (k *KASClient) unwrap(ctx context.Context, keyAccess KeyAccess, policy string) ([]byte, error) {
response, err := k.makeRewrapRequest(ctx, keyAccess, policy)
if err != nil {
return nil, fmt.Errorf("error making request to kas: %w", err)
}
Expand Down Expand Up @@ -156,7 +155,7 @@ func (k *KASClient) getNanoTDFRewrapRequest(header string, kasURL string, pubKey
return &rewrapRequest, nil
}

func (k *KASClient) makeNanoTDFRewrapRequest(header string, kasURL string, pubKey string) (*kas.RewrapResponse, error) {
func (k *KASClient) makeNanoTDFRewrapRequest(ctx context.Context, header string, kasURL string, pubKey string) (*kas.RewrapResponse, error) {
rewrapRequest, err := k.getNanoTDFRewrapRequest(header, kasURL, pubKey)
if err != nil {
return nil, err
Expand All @@ -172,7 +171,6 @@ func (k *KASClient) makeNanoTDFRewrapRequest(header string, kasURL string, pubKe
}
defer conn.Close()

ctx := context.Background()
serviceClient := kas.NewAccessServiceClient(conn)

response, err := serviceClient.Rewrap(ctx, rewrapRequest)
Expand All @@ -183,7 +181,7 @@ func (k *KASClient) makeNanoTDFRewrapRequest(header string, kasURL string, pubKe
return response, nil
}

func (k *KASClient) unwrapNanoTDF(header string, kasURL string) ([]byte, error) {
func (k *KASClient) unwrapNanoTDF(ctx context.Context, header string, kasURL string) ([]byte, error) {
keypair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1)
if err != nil {
return nil, fmt.Errorf("ocrypto.NewECKeyPair failed :%w", err)
Expand All @@ -199,7 +197,7 @@ func (k *KASClient) unwrapNanoTDF(header string, kasURL string) ([]byte, error)
return nil, fmt.Errorf("ocrypto.NewECKeyPair.PrivateKeyInPemFormat failed :%w", err)
}

response, err := k.makeNanoTDFRewrapRequest(header, kasURL, publicKeyAsPem)
response, err := k.makeNanoTDFRewrapRequest(ctx, header, kasURL, publicKeyAsPem)
if err != nil {
return nil, fmt.Errorf("error making request to kas: %w", err)
}
Expand Down Expand Up @@ -287,26 +285,67 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R
return &rewrapRequest, nil
}

type publicKeyWithID struct {
publicKey, kid string
type kasKeyRequest struct {
url, algorithm string
}

type timeStampedKASInfo struct {
KASInfo
time.Time
}

// Caches the most recent key info for a given KAS URL and algorithm
type kasKeyCache struct {
c map[kasKeyRequest]timeStampedKASInfo
}

func newKasKeyCache() *kasKeyCache {
return &kasKeyCache{make(map[kasKeyRequest]timeStampedKASInfo)}
}

func (c *kasKeyCache) clear() {
c.c = make(map[kasKeyRequest]timeStampedKASInfo)
}

func (c *kasKeyCache) get(url, algorithm string) *KASInfo {
cacheKey := kasKeyRequest{url, algorithm}
now := time.Now()
cv, ok := c.c[cacheKey]
if !ok {
return nil
}
ago := now.Add(-1 * time.Microsecond)
if ago.After(cv.Time) {
delete(c.c, cacheKey)
return nil
}
return &cv.KASInfo
}

func (s SDK) getPublicKey(kasInfo KASInfo) (*publicKeyWithID, error) {
grpcAddress, err := getGRPCAddress(kasInfo.URL)
func (c *kasKeyCache) store(ki KASInfo) {

Check failure on line 325 in sdk/kas_client.go

View workflow job for this annotation

GitHub Actions / go (sdk)

func `(*kasKeyCache).store` is unused (unused)
cacheKey := kasKeyRequest{ki.URL, ki.Algorithm}
c.c[cacheKey] = timeStampedKASInfo{ki, time.Now()}
}

func (s SDK) getPublicKey(url, algorithm string) (*KASInfo, error) {
if cachedValue := s.kasKeyCache.get(url, algorithm); nil != cachedValue {
return cachedValue, nil
}
grpcAddress, err := getGRPCAddress(url)
if err != nil {
return nil, err
}
conn, err := grpc.Dial(grpcAddress, s.dialOptions...)
if err != nil {
return nil, fmt.Errorf("error connecting to grpc service at %s: %w", kasInfo.URL, err)
return nil, fmt.Errorf("error connecting to grpc service at %s: %w", url, err)
}
defer conn.Close()

ctx := context.Background()
serviceClient := kas.NewAccessServiceClient(conn)

req := kas.PublicKeyRequest{
Algorithm: "rsa:2048",
Algorithm: algorithm,
}
if s.config.tdfFeatures.noKID {
req.V = "1"
Expand All @@ -322,5 +361,11 @@ func (s SDK) getPublicKey(kasInfo KASInfo) (*publicKeyWithID, error) {
kid = ""
}

return &publicKeyWithID{resp.GetPublicKey(), kid}, nil
a := KASInfo{
URL: url,
Algorithm: algorithm,
KID: kid,
PublicKey: resp.GetPublicKey(),
}
return &a, nil
}
1 change: 1 addition & 0 deletions sdk/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type KeyAccess struct {
PolicyBinding string `json:"policyBinding"`
EncryptedMetadata string `json:"encryptedMetadata,omitempty"`
KID string `json:"kid,omitempty"`
SplitID string `json:"sid,omitempty"`
}

type Method struct {
Expand Down
7 changes: 6 additions & 1 deletion sdk/nanotdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,11 @@ func (s SDK) CreateNanoTDF(writer io.Writer, reader io.Reader, config NanoTDFCon

// ReadNanoTDF - read the nano tdf and return the decrypted data from it
func (s SDK) ReadNanoTDF(writer io.Writer, reader io.ReadSeeker) (uint32, error) {
return s.ReadNanoTDFContext(context.Background(), writer, reader)
}

// ReadNanoTDFContext - allows cancelling the reader
func (s SDK) ReadNanoTDFContext(ctx context.Context, writer io.Writer, reader io.ReadSeeker) (uint32, error) {
header, headerSize, err := NewNanoTDFHeaderFromReader(reader)
if err != nil {
return 0, err
Expand Down Expand Up @@ -782,7 +787,7 @@ func (s SDK) ReadNanoTDF(writer io.Writer, reader io.ReadSeeker) (uint32, error)
return 0, fmt.Errorf("newKASClient failed: %w", err)
}

symmetricKey, err := client.unwrapNanoTDF(string(encodedHeader), kasURL)
symmetricKey, err := client.unwrapNanoTDF(ctx, string(encodedHeader), kasURL)
if err != nil {
return 0, fmt.Errorf("readSeeker.Seek failed: %w", err)
}
Expand Down
4 changes: 4 additions & 0 deletions sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
)

const (
// Failure while connecting to a service.
// Check your configuration and/or retry.
ErrGrpcDialFailed = Error("failed to dial grpc endpoint")
ErrShutdownFailed = Error("failed to shutdown sdk")
ErrPlatformConfigFailed = Error("failed to retrieve platform configuration")
Expand All @@ -43,6 +45,7 @@ func (c Error) Error() string {

type SDK struct {
config
*kasKeyCache
conn *grpc.ClientConn
dialOptions []grpc.DialOption
tokenSource auth.AccessTokenSource
Expand Down Expand Up @@ -146,6 +149,7 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {

return &SDK{
config: *cfg,
kasKeyCache: newKasKeyCache(),
conn: defaultConn,
dialOptions: dialOptions,
tokenSource: accessTokenSource,
Expand Down
Loading

0 comments on commit 645cf23

Please sign in to comment.