diff --git a/.github/workflows/manual-release.yml b/.github/workflows/manual-release.yml new file mode 100644 index 0000000..d7a7d3c --- /dev/null +++ b/.github/workflows/manual-release.yml @@ -0,0 +1,88 @@ +name: Manual Release + +on: + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + release: + needs: [test] + runs-on: ubuntu-24.04 + outputs: + version: ${{ steps.release.outputs.release }} + steps: + - uses: actions/checkout@v3 + - name: Set release + id: semrel + uses: go-semantic-release/action@v1 + with: + github-token: ${{ secrets.MOMENTO_MACHINE_USER_GITHUB_TOKEN }} + force-bump-patch-version: true + + - name: Output release + id: release + run: echo "release=${{ steps.semrel.outputs.version }}" >> $GITHUB_OUTPUT + + test: + runs-on: ubuntu-24.04 + permissions: + contents: read + pull-requests: read + env: + MOMENTO_API_KEY: ${{ secrets.ALPHA_TEST_AUTH_TOKEN }} + steps: + - name: Setup repo + uses: actions/checkout@v3 + + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 1.19.x + + - name: Install devtools + run: make install-devtools + + - name: Lint + run: make lint + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + only-new-issues: true + skip-pkg-cache: true + + # If there are any diffs from goimports or go mod tidy, fail. + - name: Verify no changes from goimports and go mod tidy + run: | + if [ -n "$(git status --porcelain)" ]; then + git diff + exit 1 + fi + + - name: Build + run: make build + + - name: Run test + run: make test + + publish: + needs: [test, release] + runs-on: ubuntu-24.04 + steps: + - name: Setup repo + uses: actions/checkout@v3 + + - name: Publish package + run: | + set -e + set -x + export MOMENTO_VERSION="${{needs.release.outputs.version}}" + if [ -z "$MOMENTO_VERSION"] + then + echo "Unable to determine version! Exiting!" + exit 1 + fi + echo "MOMENTO_VERSION=${MOMENTO_VERSION}" + GOPROXY=proxy.golang.org go list -m github.com/momentohq/go-aws-sdk-middlewares@v${MOMENTO_VERSION} + shell: bash diff --git a/caching/caching.go b/caching/caching.go index 4842f5d..bd18849 100644 --- a/caching/caching.go +++ b/caching/caching.go @@ -20,17 +20,38 @@ import ( "github.com/momentohq/client-sdk-go/responses" ) +type WritebackType string + +const ( + SYNCHRONOUS WritebackType = "SYNCHRONOUS" + ASYNCHRONOUS WritebackType = "ASYNCHRONOUS" + DISABLED WritebackType = "DISABLED" +) + type cachingMiddleware struct { cacheName string momentoClient momento.CacheClient + writebackType WritebackType +} + +type MiddlewareProps struct { + AwsConfig *aws.Config + CacheName string + MomentoClient momento.CacheClient + WritebackType WritebackType } -func AttachNewCachingMiddleware(cfg *aws.Config, cacheName string, client momento.CacheClient) { - cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error { +func AttachNewCachingMiddleware(props MiddlewareProps) { + if props.WritebackType == "" { + props.WritebackType = SYNCHRONOUS + } + props.MomentoClient.Logger().Debug("attaching Momento caching middleware with writeback type " + string(props.WritebackType)) + props.AwsConfig.APIOptions = append(props.AwsConfig.APIOptions, func(stack *middleware.Stack) error { return stack.Initialize.Add( NewCachingMiddleware(&cachingMiddleware{ - cacheName: cacheName, - momentoClient: client, + cacheName: props.CacheName, + momentoClient: props.MomentoClient, + writebackType: props.WritebackType, }), middleware.Before, ) @@ -70,18 +91,20 @@ func NewCachingMiddleware(mw *cachingMiddleware) middleware.InitializeMiddleware } func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input *dynamodb.BatchGetItemInput, in middleware.InitializeInput, next middleware.InitializeHandler) (middleware.InitializeOutput, error) { + if len(input.RequestItems) > 100 { + return middleware.InitializeOutput{}, errors.New("request items exceeded maximum of 100") + } + // we gather all responses from both backends in this variable to return to the user as a DDB response responsesToReturn := make(map[string][]map[string]types.AttributeValue) - // for now any cache miss is considered a miss for the whole batch - gotMiss := false // this holds the query keys for each DDB table in the request and is used to compute the cache key for // cache set operations tableToDdbKeys := make(map[string][]string) // this holds the computed Momento cache keys for each item in the request, per table tableToCacheKeys := make(map[string][]momento.Key) + cacheMissesPerTable := make(map[string]int) // gather cache keys for batch get from Momento cache for tableName, keys := range input.RequestItems { - // TODO: it may be preferable/safer to query the table for the keys gatherKeys := false if tableToDdbKeys[tableName] == nil { tableToDdbKeys[tableName] = []string{} @@ -101,14 +124,17 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input gatherKeys = false } cacheKey, err := ComputeCacheKey(tableName, key) - d.momentoClient.Logger().Debug("computed cache key for batch get retrieval: %s", cacheKey) if err != nil { return middleware.InitializeOutput{}, fmt.Errorf("error getting key for caching: %w", err) } + d.momentoClient.Logger().Debug("computed cache key for batch get retrieval: %s", cacheKey) tableToCacheKeys[tableName] = append(tableToCacheKeys[tableName], momento.String(cacheKey)) } } + gotMiss := false + gotHit := false + // Batch get from Momento cache and gather response data for tableName := range tableToCacheKeys { getResp, err := d.momentoClient.GetBatch(ctx, &momento.GetBatchRequest{ @@ -124,6 +150,7 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input for _, element := range r.Results() { switch e := element.(type) { case *responses.GetHit: + gotHit = true marshalMap, err := GetMarshalMap(e) if err != nil { return middleware.InitializeOutput{}, fmt.Errorf("error with marshal map: %w", err) @@ -131,15 +158,20 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input responsesToReturn[tableName] = append(responsesToReturn[tableName], marshalMap) case *responses.GetMiss: gotMiss = true - } - if gotMiss { - break + if _, ok := cacheMissesPerTable[tableName]; !ok { + cacheMissesPerTable[tableName] = 0 + } + cacheMissesPerTable[tableName]++ + responsesToReturn[tableName] = append(responsesToReturn[tableName], nil) } } } + if cacheMissesPerTable[tableName] > 0 { + d.momentoClient.Logger().Debug(fmt.Sprintf("got %d misses for table '%s'", cacheMissesPerTable[tableName], tableName)) + } } - // If we didn't get a miss, return the responses + // If we didn't get any misses, we return the entire response from the cache if !gotMiss { d.momentoClient.Logger().Debug("returning cached batch get responses") return middleware.InitializeOutput{ @@ -149,79 +181,97 @@ func (d *cachingMiddleware) handleBatchGetItemCommand(ctx context.Context, input }, nil } - d.momentoClient.Logger().Debug("returning DynamoDB response") - // On MISS Let middleware chains continue, so we can get result and try to cache it + // We got some misses, so there's still work for DDB to do + var newDdbRequest *dynamodb.BatchGetItemInput + if !gotHit { + // We didn't get any cache hits, so the new request is the old request + newDdbRequest = input + } else { + // compose a new DDB request with only the cache misses + newDdbRequest = &dynamodb.BatchGetItemInput{ + RequestItems: map[string]types.KeysAndAttributes{}, + } + for tableName, keys := range responsesToReturn { + if _, ok := newDdbRequest.RequestItems[tableName]; !ok { + newDdbRequest.RequestItems[tableName] = types.KeysAndAttributes{ + Keys: make([]map[string]types.AttributeValue, cacheMissesPerTable[tableName]), + } + } + missIdx := 0 + for idx, key := range keys { + if key == nil { + newDdbRequest.RequestItems[tableName].Keys[missIdx] = input.RequestItems[tableName].Keys[idx] + missIdx++ + } + } + } + } + + // re-issue the DDB request with only the cache misses + d.momentoClient.Logger().Debug("requesting items from DynamoDB") + // toReturn will be the final output to return to the user + toReturn := middleware.InitializeOutput{} + // replace the original DDB request with the new one + in.Parameters = newDdbRequest out, _, err := next.HandleInitialize(ctx, in) if err == nil { switch o := out.Result.(type) { case *dynamodb.BatchGetItemOutput: - var itemsToSet []momento.BatchSetItem - // compute and gather keys and JSON encoded items to store in Momento cache - for tableName, items := range o.Responses { - for _, item := range items { - j, err := MarshalToJson(item, d.momentoClient.Logger()) - if err != nil { - return out, err // don't return error + // check DDB responses and stitch them together with the cache responses + if !gotHit { + // if we got all misses, we can just return the DDB response after caching it + toReturn = out + } else { + for tableName, items := range responsesToReturn { + ddbResponseIdx := 0 + for idx, item := range items { + if item == nil { + responsesToReturn[tableName][idx] = o.Responses[tableName][ddbResponseIdx] + ddbResponseIdx++ + } } - - // extract the keys from the item to compute the hash key - itemForKey := map[string]types.AttributeValue{} - for _, key := range tableToDdbKeys[tableName] { - itemForKey[key] = item[key] - } - cacheKey, err := ComputeCacheKey(tableName, itemForKey) - d.momentoClient.Logger().Debug("computed cache key for batch get storage: %s", cacheKey) - if err != nil { - return middleware.InitializeOutput{}, fmt.Errorf("error getting key for caching: %w", err) - } - itemsToSet = append(itemsToSet, momento.BatchSetItem{ - Key: momento.String(cacheKey), - Value: momento.Bytes(j), - }) + } + toReturn = middleware.InitializeOutput{ + Result: &dynamodb.BatchGetItemOutput{ + Responses: responsesToReturn, + }, } } - // set item batch in Momento cache - _, err = d.momentoClient.SetBatch(ctx, &momento.SetBatchRequest{ - CacheName: d.cacheName, - Items: itemsToSet, - }) - if err != nil { - d.momentoClient.Logger().Warn( - fmt.Sprintf("error storing item batch in cache err=%+v", err), - ) - return out, nil // don't return err + + if d.writebackType == SYNCHRONOUS { + d.writeBatchResultsToCache(ctx, o, tableToDdbKeys) + } else if d.writebackType == ASYNCHRONOUS { + go d.writeBatchResultsToCache(ctx, o, tableToDdbKeys) } + } } - // unsupported output just return output and dont do anything - return out, err + return toReturn, err } func (d *cachingMiddleware) handleGetItemCommand(ctx context.Context, input *dynamodb.GetItemInput, in middleware.InitializeInput, next middleware.InitializeHandler) (middleware.InitializeOutput, error) { - - // Derive a cache key from DDB request + if input.ConsistentRead != nil { + return middleware.InitializeOutput{}, errors.New("consistent read not supported with caching middleware") + } if input.TableName == nil { return middleware.InitializeOutput{}, errors.New("error table name not set on get-item request") } + + // Derive a cache key from DDB request cacheKey, err := ComputeCacheKey(*input.TableName, input.Key) if err != nil { return middleware.InitializeOutput{}, fmt.Errorf("error getting key for caching: %w", err) } d.momentoClient.Logger().Debug("computed cache key for item retrieval: %s", cacheKey) - // Make sure we don't cache when trying to do a consistent read - if input.ConsistentRead == nil { - // Try to look up value in momento - rsp, err := d.momentoClient.Get(ctx, &momento.GetRequest{ - CacheName: d.cacheName, - Key: momento.String(cacheKey), - }) - if err != nil { - return middleware.InitializeOutput{}, fmt.Errorf("error looking up item in cache: %w", err) - } - + // Try to look up value in momento + rsp, err := d.momentoClient.Get(ctx, &momento.GetRequest{ + CacheName: d.cacheName, + Key: momento.String(cacheKey), + }) + if err == nil { switch r := rsp.(type) { case *responses.GetHit: // On hit decode value from stored json to DDB attribute map @@ -239,40 +289,92 @@ func (d *cachingMiddleware) handleGetItemCommand(ctx context.Context, input *dyn // Just log on miss d.momentoClient.Logger().Debug("momento lookup did not find key: " + cacheKey) } + } else { + d.momentoClient.Logger().Warn( + fmt.Sprintf("error looking up item in cache err=%+v", err), + ) } d.momentoClient.Logger().Debug("returning DynamoDB response") // On MISS Let middleware chains continue, so we can get result and try to cache it out, _, err := next.HandleInitialize(ctx, in) - if err == nil { + if err == nil && d.writebackType != DISABLED { switch o := out.Result.(type) { case *dynamodb.GetItemOutput: + if d.writebackType == SYNCHRONOUS { + d.writeResultToCache(ctx, o, cacheKey) + } else if d.writebackType == ASYNCHRONOUS { + go d.writeResultToCache(ctx, o, cacheKey) + } + } + } + return out, err +} - // unmarshal raw response object to DDB attribute values map and encode as json - j, err := MarshalToJson(o.Item, d.momentoClient.Logger()) +func (d *cachingMiddleware) writeResultToCache(ctx context.Context, ddbOutput *dynamodb.GetItemOutput, cacheKey string) { + // unmarshal raw response object to DDB attribute values map and encode as json + j, err := MarshalToJson(ddbOutput.Item, d.momentoClient.Logger()) + if err != nil { + d.momentoClient.Logger().Warn(fmt.Sprintf("error marshalling item to json: %+v", err)) + } + + d.momentoClient.Logger().Debug(fmt.Sprintf("caching item with key: %s", cacheKey)) + // set item in momento cache + _, err = d.momentoClient.Set(ctx, &momento.SetRequest{ + CacheName: d.cacheName, + Key: momento.String(cacheKey), + Value: momento.Bytes(j), + }) + if err != nil { + d.momentoClient.Logger().Warn( + fmt.Sprintf("error storing item in cache err=%+v", err), + ) + } +} + +func (d *cachingMiddleware) writeBatchResultsToCache(ctx context.Context, ddbOutput *dynamodb.BatchGetItemOutput, tableToDdbKeys map[string][]string) { + d.momentoClient.Logger().Debug("storing dynamodb items in cache") + var itemsToSet []momento.BatchSetItem + // compute and gather keys and JSON encoded items to store in Momento cache + for tableName, items := range ddbOutput.Responses { + for _, item := range items { + j, err := MarshalToJson(item, d.momentoClient.Logger()) if err != nil { - return out, err // don't return error + d.momentoClient.Logger().Warn(fmt.Sprintf("error marshalling item to json: %+v", err)) + continue } - d.momentoClient.Logger().Debug("caching item with key %s: " + cacheKey) - // set item in momento cache - _, err = d.momentoClient.Set(ctx, &momento.SetRequest{ - CacheName: d.cacheName, - Key: momento.String(cacheKey), - Value: momento.Bytes(j), - }) + // extract the keys from the item to compute the hash key + itemForKey := map[string]types.AttributeValue{} + for _, key := range tableToDdbKeys[tableName] { + itemForKey[key] = item[key] + } + cacheKey, err := ComputeCacheKey(tableName, itemForKey) if err != nil { - d.momentoClient.Logger().Warn( - fmt.Sprintf("error storing item in cache err=%+v", err), - ) - return out, nil // don't return err + d.momentoClient.Logger().Warn(fmt.Sprintf("error getting key for caching: %+v", err)) + continue } + d.momentoClient.Logger().Debug("computed cache key for batch get storage: %s", cacheKey) + + itemsToSet = append(itemsToSet, momento.BatchSetItem{ + Key: momento.String(cacheKey), + Value: momento.Bytes(j), + }) } } + // set item batch in Momento cache + _, err := d.momentoClient.SetBatch(ctx, &momento.SetBatchRequest{ + CacheName: d.cacheName, + Items: itemsToSet, + }) + if err != nil { + d.momentoClient.Logger().Warn( + fmt.Sprintf("error storing item batch in cache err=%+v", err), + ) + } + d.momentoClient.Logger().Debug("stored dynamodb items in cache") - // unsupported output just return output and dont do anything - return out, err } func ComputeCacheKey(tableName string, keys map[string]types.AttributeValue) (string, error) { diff --git a/caching/caching_test.go b/caching/caching_test.go index a37eafa..a1c8e14 100644 --- a/caching/caching_test.go +++ b/caching/caching_test.go @@ -36,22 +36,18 @@ type Movie struct { } var ( - momentoClient momento.CacheClient - ddbClient *dynamodb.Client - tableInfo TableBasics - tableName = "movies" - movie1 = Movie{ - Title: "A Movie Part 1", - Year: 2021, - } - movie2 = Movie{ - Title: "A Movie Part 2", - Year: 2021, - } + momentoClient momento.CacheClient + ddbClient *dynamodb.Client + tableInfo TableBasics + tableName = "movies" + movies []Movie + movie1 Movie + movie2 Movie movie1hash = "1e21f0974977886cb33d2ca173f89cb9c3c1c5e84712ee07d3fab031817751f2" movie2hash = "f334e26f2f40da3172e2dd668a18c58b95b2472a8891ea5a0c63d67ed57c6660" movie1json2022 = "{\"info\":null,\"title\":\"A Movie Part 1\",\"year\":2022}" movie2json2022 = "{\"info\":null,\"title\":\"A Movie Part 2\",\"year\":2022}" + writebackType = SYNCHRONOUS ) func setupTest() func() { @@ -75,25 +71,37 @@ func setupTest() func() { if err != nil { panic(err) } - ddbClient = getDdbClientWithMiddleware(momentoClient) + + // writebackType defaults to synchronous but can be modified before calling `setupTest()` + // you may also instantiate additional clients to test, passing different values for writebackType + // to `getDdbClientWithMiddleware()` + ddbClient = getDdbClientWithMiddleware(momentoClient, &writebackType) amazonConfig := mustGetAWSConfig() ddbControlClient := dynamodb.NewFromConfig(amazonConfig) tableInfo = TableBasics{DynamoDbClient: ddbControlClient, TableName: tableName} + momentoClient.Logger().Debug("Populating DDB with movies") _, err = tableInfo.createTestTable() if err != nil { panic(err) } - for i := 0; i < 20; i++ { - err = tableInfo.addMovie(Movie{ - Title: "A Movie Part " + fmt.Sprint(i), + // insert movies in DDB with year = 2021 + for i := 0; i < 50; i++ { + movie := Movie{ + Title: "A Movie Part " + fmt.Sprint(i+1), Year: 2021, - }) + } + err = tableInfo.addMovie(movie) if err != nil { panic(fmt.Errorf("error adding data: %+v", err)) } + movies = append(movies, movie) } + momentoClient.Logger().Debug("done populating data") + + movie1 = movies[0] + movie2 = movies[1] // teardown function return func() { @@ -105,12 +113,12 @@ func setupTest() func() { panic(err) } momentoClient.Close() + writebackType = SYNCHRONOUS } } -func TestGetItemCacheMiss(t *testing.T) { - defer setupTest()() - +// cache miss tests +func testGetItemCacheMissCommon(t *testing.T) (Movie, responses.GetResponse) { // Execute GetItem Request as you would normally resp, err := ddbClient.GetItem(context.TODO(), &dynamodb.GetItemInput{ TableName: aws.String(tableName), @@ -125,6 +133,7 @@ func TestGetItemCacheMiss(t *testing.T) { t.Errorf("error decoding dynamodb response: %+v", err) } + time.Sleep(1 * time.Second) getResp, err := momentoClient.Get(context.Background(), &momento.GetRequest{ CacheName: tableName, Key: momento.String(movie1hash), @@ -132,6 +141,13 @@ func TestGetItemCacheMiss(t *testing.T) { if err != nil { t.Errorf("error occured calling momento get: %+v", err) } + return movie, getResp +} + +func TestGetItemCacheMiss(t *testing.T) { + defer setupTest()() + + movie, getResp := testGetItemCacheMissCommon(t) switch r := getResp.(type) { case *responses.GetHit: movieInfo, err := getMapFromJsonBytes(r.ValueByte()) @@ -149,7 +165,28 @@ func TestGetItemCacheMiss(t *testing.T) { } } -func TestGetItemHit(t *testing.T) { +func TestGetItemCacheMissAsync(t *testing.T) { + writebackType = ASYNCHRONOUS + TestGetItemCacheMiss(t) +} + +func TestGetItemCacheMissNoWriteback(t *testing.T) { + writebackType = DISABLED + defer setupTest()() + _, getResp := testGetItemCacheMissCommon(t) + switch getResp.(type) { + case *responses.GetHit: + t.Errorf("expected cache miss, got cache hit") + } +} + +// cache hit tests +func TestGetItemCacheHitAsync(t *testing.T) { + writebackType = ASYNCHRONOUS + TestGetItemCacheHit(t) +} + +func TestGetItemCacheHit(t *testing.T) { defer setupTest()() _, err := momentoClient.Set(context.Background(), &momento.SetRequest{ @@ -180,10 +217,11 @@ func TestGetItemHit(t *testing.T) { } } +// cache error test func TestGetItemError(t *testing.T) { defer setupTest()() mmc := &mockMomentoClient{} - ddbClient := getDdbClientWithMiddleware(mmc) + ddbClient := getDdbClientWithMiddleware(mmc, nil) // Execute GetItem Request as you would normally resp, err := ddbClient.GetItem(context.TODO(), &dynamodb.GetItemInput{ @@ -198,11 +236,13 @@ func TestGetItemError(t *testing.T) { if err != nil { t.Errorf("error decoding dynamodb response: %+v", err) } + momentoClient.Logger().Debug(fmt.Sprintf("movie: %+v", movie)) if movie.Year != 2021 { t.Errorf("expected ddb hit year to be 2021: %+v", movie) } } +// batch get tests - hits func TestBatchGetItemAllHits(t *testing.T) { defer setupTest()() @@ -250,7 +290,13 @@ func TestBatchGetItemAllHits(t *testing.T) { } } -func TestBatchGetItemAllMisses(t *testing.T) { +func TestBatchGetItemAllHitsAsync(t *testing.T) { + writebackType = ASYNCHRONOUS + TestBatchGetItemAllHits(t) +} + +// batch get tests - misses +func testBatchGetItemAllMissesCommon(t *testing.T) responses.GetBatchResponse { defer setupTest()() req := &dynamodb.BatchGetItemInput{ @@ -279,6 +325,9 @@ func TestBatchGetItemAllMisses(t *testing.T) { } } + // give the middleware goroutine a little time to finish caching DDB data for the Momento misses + time.Sleep(1 * time.Second) + // make sure results were set in Momento cache getResp, err := momentoClient.GetBatch(context.Background(), &momento.GetBatchRequest{ CacheName: tableName, @@ -290,7 +339,11 @@ func TestBatchGetItemAllMisses(t *testing.T) { if err != nil { t.Errorf("error occured calling momento get: %+v", err) } + return getResp +} +func TestBatchGetItemAllMisses(t *testing.T) { + getResp := testBatchGetItemAllMissesCommon(t) switch r := getResp.(type) { case responses.GetBatchSuccess: for _, element := range r.Results() { @@ -312,9 +365,29 @@ func TestBatchGetItemAllMisses(t *testing.T) { } } -func TestBatchGetItemsMixed(t *testing.T) { - defer setupTest()() +func TestBatchGetItemAllMissesAsync(t *testing.T) { + writebackType = ASYNCHRONOUS + TestBatchGetItemAllMisses(t) +} +func TestBatchGetItemAllMissesNoWriteback(t *testing.T) { + writebackType = DISABLED + getResp := testBatchGetItemAllMissesCommon(t) + switch r := getResp.(type) { + case responses.GetBatchSuccess: + for _, element := range r.Results() { + switch element.(type) { + case *responses.GetHit: + t.Errorf("expected cache hit, got cache miss") + } + } + default: + t.Errorf("unknown get batch response type: %T\n", r) + } +} + +// batch get tests - mixed hits and misses +func testBatchGetItemsMixedCommon(t *testing.T) responses.GetBatchResponse { _, err := momentoClient.Set(context.Background(), &momento.SetRequest{ CacheName: tableName, Key: momento.String(movie1hash), @@ -334,22 +407,30 @@ func TestBatchGetItemsMixed(t *testing.T) { }, }, } + time.Sleep(1 * time.Second) resp, err := ddbClient.BatchGetItem(context.TODO(), req) if err != nil { t.Errorf("error occurred calling batch get item: %+v\n", err) } + for _, items := range resp.Responses { for _, item := range items { movie, err := getMovieFromDdbItem(item) if err != nil { t.Errorf("error decoding dynamodb response: %+v", err) } - if movie.Year != 2021 { + if movie.Title == "A Movie Part 1" && movie.Year != 2022 { + t.Errorf("expected cache hit year to be 2022: %+v", movie) + } + if movie.Title == "A Movie Part 2" && movie.Year != 2021 { t.Errorf("expected ddb hit year to be 2021: %+v", movie) } } } + // give the middleware goroutine a little time to finish caching DDB data for the Momento misses + time.Sleep(500 * time.Millisecond) + // make sure cached versions were overwritten/written getResp, err := momentoClient.GetBatch(context.Background(), &momento.GetBatchRequest{ CacheName: tableName, @@ -361,6 +442,12 @@ func TestBatchGetItemsMixed(t *testing.T) { if err != nil { t.Errorf("error occured calling momento get: %+v", err) } + return getResp +} + +func TestBatchGetItemsMixed(t *testing.T) { + defer setupTest()() + getResp := testBatchGetItemsMixedCommon(t) switch r := getResp.(type) { case responses.GetBatchSuccess: for _, element := range r.Results() { @@ -370,9 +457,12 @@ func TestBatchGetItemsMixed(t *testing.T) { if err != nil { t.Errorf("error decoding cache hit: %+v", err) } - if fmt.Sprint(movieInfo["year"]) != fmt.Sprint(2021) { + if movieInfo["title"] == "A Movie Part 1" && fmt.Sprint(movieInfo["year"]) != fmt.Sprint(2022) { t.Errorf("expected cache hit year to match ddb response: %+v", movieInfo) } + if movieInfo["title"] == "A Movie Part 2" && fmt.Sprint(movieInfo["year"]) != fmt.Sprint(2021) { + t.Errorf("expected ddb hit year: %+v", movieInfo) + } case *responses.GetMiss: t.Errorf("expected cache hit, got cache miss") } @@ -380,10 +470,37 @@ func TestBatchGetItemsMixed(t *testing.T) { } } +func TestBatchGetItemsMixedAsync(t *testing.T) { + writebackType = ASYNCHRONOUS + TestBatchGetItemsMixed(t) +} + +func TestBatchGetItemsMixedNoWriteback(t *testing.T) { + writebackType = DISABLED + defer setupTest()() + getResp := testBatchGetItemsMixedCommon(t) + switch r := getResp.(type) { + case responses.GetBatchSuccess: + for _, element := range r.Results() { + switch e := element.(type) { + case *responses.GetHit: + movieInfo, err := getMapFromJsonBytes(e.ValueByte()) + if err != nil { + t.Errorf("error decoding cache hit: %+v", err) + } + if movieInfo["title"] != "A Movie Part 1" { + t.Errorf("expected cache miss but got: %+v", movieInfo) + } + } + } + } +} + +// batch get test with error func TestBatchGetItemsError(t *testing.T) { defer setupTest()() mmc := &mockMomentoClient{} - ddbClient := getDdbClientWithMiddleware(mmc) + ddbClient := getDdbClientWithMiddleware(mmc, nil) req := &dynamodb.BatchGetItemInput{ RequestItems: map[string]types.KeysAndAttributes{ @@ -469,9 +586,18 @@ func getMovieFromDdbItem(item map[string]types.AttributeValue) (Movie, error) { return movie, nil } -func getDdbClientWithMiddleware(momentoClient momento.CacheClient) *dynamodb.Client { +func getDdbClientWithMiddleware(momentoClient momento.CacheClient, writebackType *WritebackType) *dynamodb.Client { amazonConfiguration := mustGetAWSConfig() - AttachNewCachingMiddleware(&amazonConfiguration, tableName, momentoClient) + var wb WritebackType + if writebackType != nil { + wb = *writebackType + } + AttachNewCachingMiddleware(MiddlewareProps{ + &amazonConfiguration, + tableName, + momentoClient, + wb, + }) return dynamodb.NewFromConfig(amazonConfiguration) }