Skip to content

Commit

Permalink
feat: make serializer used configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
eaddingtonwhite committed Sep 30, 2024
1 parent 2ae48c3 commit 85a241b
Show file tree
Hide file tree
Showing 6 changed files with 356 additions and 108 deletions.
103 changes: 41 additions & 62 deletions caching/caching.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package caching

import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"sort"

"github.com/momentohq/go-aws-sdk-middlewares/internal/serializer"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/aws/smithy-go/middleware"
"github.com/momentohq/client-sdk-go/config/logger"
"github.com/momentohq/client-sdk-go/momento"
"github.com/momentohq/client-sdk-go/responses"
)
Expand All @@ -28,41 +27,60 @@ const (
DISABLED WritebackType = "DISABLED"
)

// MsgPackSerializer is an alias for serializer.MsgPackSerializer, which lives
// in this module's internal serializer package; the alias makes the type
// nameable through the caching package.
type MsgPackSerializer = serializer.MsgPackSerializer
// JSONSerializer is an alias for serializer.JSONSerializer from the internal
// serializer package. AttachNewCachingMiddleware falls back to this serializer
// when MiddlewareProps.Serializer is nil.
type JSONSerializer = serializer.JSONSerializer

// cachingMiddleware carries the configuration and clients used by the
// DynamoDB caching middleware handlers.
type cachingMiddleware struct {
// cacheName is the Momento cache name passed on cache requests.
cacheName string
// momentoClient issues the cache requests and supplies the logger used for
// debug and warning output.
momentoClient momento.CacheClient
// writebackType selects the write-back mode; ASYNCHRONOUS causes the async
// batch writer to be started when the middleware is constructed.
writebackType WritebackType
// asyncWriteChan carries batch set requests.
// NOTE(review): presumably drained by the async batch writer; its producer
// and consumer are outside this view — confirm.
asyncWriteChan chan *momento.SetBatchRequest
// serializer converts DynamoDB items to and from the bytes stored in the
// cache, and its name is mixed into computed cache keys.
serializer Serializer
}

// Serializer defines the methods for serializing and deserializing DynamoDB
// items for storage in the cache.
type Serializer interface {
// Name returns an identifier for the serializer. ComputeCacheKey includes
// it in the hashed key material, so the same item gets a different cache
// key under a differently named serializer.
Name() string
// Serialize encodes a DynamoDB attribute map into the bytes that are
// written to the cache.
Serialize(item map[string]types.AttributeValue) ([]byte, error)
// Deserialize decodes bytes read from a cache hit back into a DynamoDB
// attribute map.
Deserialize(data []byte) (map[string]types.AttributeValue, error)
}

// MiddlewareProps bundles the inputs to AttachNewCachingMiddleware.
type MiddlewareProps struct {
// AwsConfig is the AWS configuration whose APIOptions the caching
// middleware is appended to. It is dereferenced without a nil check.
AwsConfig *aws.Config
// CacheName is the Momento cache name used for cache requests.
CacheName string
// MomentoClient is the cache client used for reads and writes; its logger
// is also used by the middleware. It is called without a nil check.
MomentoClient momento.CacheClient
// WritebackType selects the write-back mode. The empty value is replaced
// with SYNCHRONOUS.
WritebackType WritebackType
// Serializer encodes and decodes cached items. A nil value is replaced
// with serializer.JSONSerializer{}.
Serializer Serializer
}

func AttachNewCachingMiddleware(props MiddlewareProps) {
if props.WritebackType == "" {
props.WritebackType = SYNCHRONOUS
}

if props.Serializer == nil {
props.Serializer = serializer.JSONSerializer{}
}

props.MomentoClient.Logger().Debug("attaching Momento caching middleware with writeback type " + string(props.WritebackType))
props.AwsConfig.APIOptions = append(props.AwsConfig.APIOptions, func(stack *middleware.Stack) error {
return stack.Initialize.Add(
NewCachingMiddleware(&cachingMiddleware{
newCachingMiddleware(&cachingMiddleware{
cacheName: props.CacheName,
momentoClient: props.MomentoClient,
writebackType: props.WritebackType,
serializer: props.Serializer,
}),
middleware.Before,
)
})
}

func NewCachingMiddleware(mw *cachingMiddleware) middleware.InitializeMiddleware {
func newCachingMiddleware(mw *cachingMiddleware) middleware.InitializeMiddleware {
if mw.writebackType == ASYNCHRONOUS {
mw.startAsyncBatchWriter()
}

return middleware.InitializeMiddlewareFunc("CachingMiddleware", func(
ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) {
Expand Down Expand Up @@ -167,7 +185,7 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input
}
gatherKeys = false
}
cacheKey, err := ComputeCacheKey(tableName, key)
cacheKey, err := ComputeCacheKey(tableName, key, d.serializer)
if err != nil {
return middleware.InitializeOutput{}, fmt.Errorf("error getting key for caching: %w", err)
}
Expand Down Expand Up @@ -195,11 +213,11 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input
switch e := element.(type) {
case *responses.GetHit:
gotHit = true
marshalMap, err := GetMarshalMap(e)
deserializedMap, err := d.serializer.Deserialize(e.ValueByte())
if err != nil {
return middleware.InitializeOutput{}, fmt.Errorf("error with marshal map: %w", err)
return middleware.InitializeOutput{}, fmt.Errorf("error with desrializing map: %w", err)
}
responsesToReturn[tableName] = append(responsesToReturn[tableName], marshalMap)
responsesToReturn[tableName] = append(responsesToReturn[tableName], deserializedMap)
case *responses.GetMiss:
gotMiss = true
if _, ok := cacheMissesPerTable[tableName]; !ok {
Expand Down Expand Up @@ -315,7 +333,7 @@ func (d *cachingMiddleware) handleGetItemCommand(ctx context.Context, input *dyn
}

// Derive a cache key from DDB request
cacheKey, err := ComputeCacheKey(*input.TableName, input.Key)
cacheKey, err := ComputeCacheKey(*input.TableName, input.Key, d.serializer)
if err != nil {
return middleware.InitializeOutput{}, fmt.Errorf("error getting key for caching: %w", err)
}
Expand All @@ -329,15 +347,15 @@ func (d *cachingMiddleware) handleGetItemCommand(ctx context.Context, input *dyn
if err == nil {
switch r := rsp.(type) {
case *responses.GetHit:
// On hit decode value from stored json to DDB attribute map
marshalMap, err := GetMarshalMap(r)
// On hit decode value from value stored in cache to a DDB attribute map
deserializedMap, err := d.serializer.Deserialize(r.ValueByte())
if err != nil {
return middleware.InitializeOutput{}, fmt.Errorf("error with marshal map: %w", err)
}
d.momentoClient.Logger().Debug("returning cached item")
// Return user spoofed dynamodb.GetItemOutput.Item w/ cached value
return struct{ Result interface{} }{Result: &dynamodb.GetItemOutput{
Item: marshalMap,
Item: deserializedMap,
}}, nil

case *responses.GetMiss:
Expand Down Expand Up @@ -368,18 +386,17 @@ func (d *cachingMiddleware) handleGetItemCommand(ctx context.Context, input *dyn
}

func (d *cachingMiddleware) writeResultToCache(ctx context.Context, ddbOutput *dynamodb.GetItemOutput, cacheKey string) {
// unmarshal raw response object to DDB attribute values map and encode as json
j, err := MarshalToJson(ddbOutput.Item, d.momentoClient.Logger())
b, err := d.serializer.Serialize(ddbOutput.Item)
if err != nil {
d.momentoClient.Logger().Warn(fmt.Sprintf("error marshalling item to json: %+v", err))
d.momentoClient.Logger().Warn(fmt.Sprintf("error serializing item: %+v", err))
}

d.momentoClient.Logger().Debug(fmt.Sprintf("caching item with key: %s", cacheKey))
// set item in momento cache
_, err = d.momentoClient.Set(ctx, &momento.SetRequest{
CacheName: d.cacheName,
Key: momento.String(cacheKey),
Value: momento.Bytes(j),
Value: momento.Bytes(b),
})
if err != nil {
d.momentoClient.Logger().Warn(
Expand Down Expand Up @@ -410,9 +427,9 @@ func (d *cachingMiddleware) prepareMomentoBatchGetRequest(ddbOutput *dynamodb.Ba
// compute and gather keys and JSON encoded items to store in Momento cache
for tableName, items := range ddbOutput.Responses {
for _, item := range items {
j, err := MarshalToJson(item, d.momentoClient.Logger())
b, err := d.serializer.Serialize(item)
if err != nil {
d.momentoClient.Logger().Warn(fmt.Sprintf("error marshalling item to json: %+v", err))
d.momentoClient.Logger().Warn(fmt.Sprintf("error seralizing item: %+v", err))
continue
}

Expand All @@ -421,7 +438,7 @@ func (d *cachingMiddleware) prepareMomentoBatchGetRequest(ddbOutput *dynamodb.Ba
for _, key := range tableToDdbKeys[tableName] {
itemForKey[key] = item[key]
}
cacheKey, err := ComputeCacheKey(tableName, itemForKey)
cacheKey, err := ComputeCacheKey(tableName, itemForKey, d.serializer)
if err != nil {
d.momentoClient.Logger().Warn(fmt.Sprintf("error getting key for caching: %+v", err))
continue
Expand All @@ -430,7 +447,7 @@ func (d *cachingMiddleware) prepareMomentoBatchGetRequest(ddbOutput *dynamodb.Ba

itemsToSet = append(itemsToSet, momento.BatchSetItem{
Key: momento.String(cacheKey),
Value: momento.Bytes(j),
Value: momento.Bytes(b),
})
}
}
Expand All @@ -440,7 +457,7 @@ func (d *cachingMiddleware) prepareMomentoBatchGetRequest(ddbOutput *dynamodb.Ba
}
}

func ComputeCacheKey(tableName string, keys map[string]types.AttributeValue) (string, error) {
func ComputeCacheKey(tableName string, keys map[string]types.AttributeValue, serializer Serializer) (string, error) {
// Marshal to attribute map
var t map[string]interface{}
err := attributevalue.UnmarshalMap(keys, &t)
Expand All @@ -461,50 +478,12 @@ func ComputeCacheKey(tableName string, keys map[string]types.AttributeValue) (st
out += fieldToValue[k]
}

// prefix key w/ table name and convert to fixed length hash
// prefix key w/ table name + serializer used and convert to fixed length hash
hash := sha256.New()
hash.Write([]byte(tableName + out))
hash.Write([]byte(tableName + serializer.Name() + out))
return hex.EncodeToString(hash.Sum(nil)), nil
}

// GetMarshalMap turns the JSON payload of a cache hit into a DynamoDB
// attribute map. It returns a wrapped error if the payload cannot be decoded
// as JSON or the decoded value cannot be marshalled to attribute values.
func GetMarshalMap(r *responses.GetHit) (map[string]types.AttributeValue, error) {
	// Step 1: decode the cached JSON document into a generic Go map.
	var decoded map[string]interface{}
	decoder := json.NewDecoder(bytes.NewReader(r.ValueByte()))
	if decodeErr := decoder.Decode(&decoded); decodeErr != nil {
		return nil, fmt.Errorf("error decoding json item in cache to return: %w", decodeErr)
	}

	// Step 2: convert the generic map into the shape used by
	// dynamodb.GetItemOutput.Item.
	ddbItem, marshalErr := attributevalue.MarshalMap(decoded)
	if marshalErr != nil {
		return nil, fmt.Errorf("error encoding item in cache to ddbItem to return: %w", marshalErr)
	}
	return ddbItem, nil
}

// MarshalToJson encodes a DynamoDB attribute map as JSON bytes suitable for
// storing in the cache.
//
// The item is first unmarshalled into a generic Go map and then JSON encoded.
// On failure a warning is written to logger and the underlying error is
// returned wrapped with %w, so callers can inspect it with errors.Is or
// errors.As. The returned byte slice is nil whenever the error is non-nil.
func MarshalToJson(item map[string]types.AttributeValue, logger logger.MomentoLogger) ([]byte, error) {
	// Unmarshal the raw attribute values into plain Go values.
	var t map[string]interface{}
	if err := attributevalue.UnmarshalMap(item, &t); err != nil {
		logger.Warn(
			fmt.Sprintf("error decoding output item to store in cache err=%+v", err),
		)
		return nil, fmt.Errorf("error decoding output item to store in cache: %w", err)
	}

	// Marshal to JSON to store in cache.
	j, err := json.Marshal(t)
	if err != nil {
		logger.Warn(
			fmt.Sprintf("error json encoding new item to store in cache err=%+v", err),
		)
		return nil, fmt.Errorf("error json encoding new item to store in cache: %w", err)
	}
	return j, nil
}

// safeGetDDItemFromResponseSlice safely checks the slice of DDB item responses to avoid a panic
func safeGetDDItemFromResponseSlice(slice []map[string]types.AttributeValue, index int) (map[string]types.AttributeValue, bool) {
if index >= 0 && index < len(slice) {
Expand Down
Loading

0 comments on commit 85a241b

Please sign in to comment.