diff --git a/internal/internal_workflow_test.go b/internal/internal_workflow_test.go index efa6912d2..94d40bd46 100644 --- a/internal/internal_workflow_test.go +++ b/internal/internal_workflow_test.go @@ -1406,6 +1406,19 @@ type updateCallback struct { accept func() reject func(error) complete func(interface{}, error) + env *TestWorkflowEnvironment + updateID string +} + +// env and updateID are needed to cache update results for deduping purposes +func newUpdateCallback(env *TestWorkflowEnvironment, updateID string, accept func(), reject func(error), complete func(interface{}, error)) *updateCallback { + return &updateCallback{ + accept: accept, + reject: reject, + complete: complete, + env: env, + updateID: updateID, + } } func (uc *updateCallback) Accept() { @@ -1417,6 +1430,11 @@ func (uc *updateCallback) Reject(err error) { } func (uc *updateCallback) Complete(success interface{}, err error) { + // cache update result so we can dedup duplicate update IDs + if uc.env == nil { + panic("env is needed in updateCallback to cache update results for deduping purposes") + } + uc.env.impl.updateMap[uc.env.impl.currentUpdateId] = updateResult{success, err} uc.complete(success, err) } diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 3b9742f73..d1bb780ad 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -143,6 +143,11 @@ type ( taskQueues map[string]struct{} } + updateResult struct { + success interface{} + err error + } + // testWorkflowEnvironmentShared is the shared data between parent workflow and child workflow test environments testWorkflowEnvironmentShared struct { locker sync.Mutex @@ -208,6 +213,7 @@ type ( signalHandler func(name string, input *commonpb.Payloads, header *commonpb.Header) error queryHandler func(string, *commonpb.Payloads, *commonpb.Header) (*commonpb.Payloads, error) updateHandler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks) + updateMap map[string]updateResult startedHandler func(r WorkflowExecution, e error) isWorkflowCompleted bool @@ -229,6 +235,8 @@ type ( workflowFunctionExecuting bool bufferedUpdateRequests map[string][]func() + + currentUpdateId string } testSessionEnvironmentImpl struct { @@ -2179,6 +2187,7 @@ func (env *testWorkflowEnvironmentImpl) RegisterUpdateHandler( handler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks), ) { env.updateHandler = handler + env.updateMap = make(map[string]updateResult) } func (env *testWorkflowEnvironmentImpl) RegisterQueryHandler( @@ -2732,10 +2741,24 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, id string, u if err != nil { panic(err) } - env.postCallback(func() { - // Do not send any headers on test invocations - env.updateHandler(name, id, data, nil, uc) - }, true) + + // check for duplicate update ID + if _, ok := env.updateMap[id]; ok { + // return cached result + env.postCallback(func() { + uc.Accept() + uc.Complete(env.updateMap[id].success, env.updateMap[id].err) + }, false) + } else { + // TODO: This doesn't account for multiple async updates + // would a UC -> ID map work? Would I have to use pointers? + env.currentUpdateId = id + env.postCallback(func() { + // Do not send any headers on test invocations + env.updateHandler(name, id, data, nil, uc) + }, true) + } + } func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id string, uc UpdateCallbacks, args ...interface{}) error { @@ -2747,6 +2770,7 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id if err != nil { panic(err) } + // TODO: handle dedup workflowHandle.env.postCallback(func() { workflowHandle.env.updateHandler(name, id, data, nil, uc) }, true) diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index 3fc46146b..482f506f9 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -273,14 +273,17 @@ func TestWorkflowIDUpdateWorkflowByID(t *testing.T) { var suite WorkflowTestSuite // Test UpdateWorkflowByID works with custom ID env := suite.NewTestWorkflowEnvironment() + updateID := "id" env.RegisterDelayedCallback(func() { - err := env.UpdateWorkflowByID("my-workflow-id", "update", "id", &updateCallback{ - reject: func(err error) { + err := env.UpdateWorkflowByID("my-workflow-id", "update", updateID, newUpdateCallback( + env, + updateID, + func() {}, + func(err error) { require.Fail(t, "update should not be rejected") }, - accept: func() {}, - complete: func(interface{}, error) {}, - }, "input") + func(interface{}, error) {}, + ), "input") require.NoError(t, err) }, time.Second) @@ -322,6 +325,7 @@ func TestChildWorkflowUpdate(t *testing.T) { require.Fail(t, "update failed", err) } }, + env: env, }, nil) assert.NoError(t, err) }, time.Second*5) @@ -375,6 +379,7 @@ func TestWorkflowUpdateOrder(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -415,6 +420,7 @@ func TestWorkflowNotRegisteredRejected(t *testing.T) { require.Fail(t, "update should not be accepted") }, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -439,6 +445,7 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -452,6 +459,7 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { require.Fail(t, "update should not be rejected") }, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -462,6 +470,7 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -491,6 +500,65 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { require.Equal(t, "unknown update bad update. KnownUpdates=[update]", updateRejectionErr.Error()) } +func TestWorkflowDuplicateIDDedup(t *testing.T) { + var suite WorkflowTestSuite + // Test dev server dedups UpdateWorkflow with same ID + env := suite.NewTestWorkflowEnvironment() + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &updateCallback{ + reject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + accept: func() { + }, + complete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, "result should be int") + } else { + require.Equal(t, 0, intResult) + } + }, + env: env, + updateID: "id", + }, 0) + }, 0) + + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &updateCallback{ + reject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + accept: func() { + }, + complete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, "result should be int") + } else { + // if dedup, this be okay, even if we pass in 1 as arg, since it's deduping, + // the result should match the first update's result, 0 + require.Equal(t, 0, intResult) + } + }, + env: env, + }, 1) + + }, 1*time.Millisecond) + + env.ExecuteWorkflow(func(ctx Context) error { + err := SetUpdateHandler(ctx, "update", func(ctx Context, i int) (int, error) { + return i, nil + }, UpdateHandlerOptions{}) + if err != nil { + return err + } + return Sleep(ctx, time.Hour) + }) + require.NoError(t, env.GetWorkflowError()) + require.True(t, false) +} + func TestAllHandlersFinished(t *testing.T) { var suite WorkflowTestSuite env := suite.NewTestWorkflowEnvironment() @@ -502,6 +570,7 @@ func TestAllHandlersFinished(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -512,6 +581,7 @@ func TestAllHandlersFinished(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, time.Minute) @@ -576,6 +646,7 @@ func TestWorkflowAllHandlersFinished(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0) @@ -586,6 +657,7 @@ func TestWorkflowAllHandlersFinished(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, time.Minute) @@ -596,6 +668,7 @@ func TestWorkflowAllHandlersFinished(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 2*time.Minute) @@ -733,6 +806,7 @@ func TestWorkflowUpdateLogger(t *testing.T) { }, accept: func() {}, complete: func(interface{}, error) {}, + env: env, }) }, 0)