From 33bb7990e8be4896f0c7129c8f1de7c6b96d5981 Mon Sep 17 00:00:00 2001 From: rodrigozhou Date: Mon, 28 Oct 2024 13:39:20 -0700 Subject: [PATCH] address comments --- internal/internal_workflow_testsuite.go | 86 +++++++++++++++---------- test/nexus_test.go | 23 +++++++ 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index a628f86e8..7cbf77552 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2481,25 +2481,8 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( handle.startedCallback(opID, nil) if handle.cancelRequested { handle.cancel() - } else { - completionHandle := env.getNexusAsyncOperationCompletionHandle( - handle.params.client.Service(), - handle.params.operation, - opID, - ) - if completionHandle != nil { - env.deleteNexusAsyncOperationCompletionHandle( - handle.params.client.Service(), - handle.params.operation, - opID, - ) - env.registerDelayedCallback( - func() { - env.resolveNexusOperation(seq, completionHandle.result, completionHandle.err) - }, - completionHandle.delay, - ) - } + } else if handle.isMocked { + env.scheduleNexusAsyncOperationCompletion(handle) } }, true) case *nexuspb.StartOperationResponse_OperationError: @@ -2569,11 +2552,6 @@ func (env *testWorkflowEnvironmentImpl) RegisterNexusAsyncOperationCompletion( } } - if err != nil { - // The handler workflow error needs to wrapped so it can be passed to the caller correctly. - err = NewApplicationError(err.Error(), "", true, err) - } - // Getting the locker to prevent race condition if this function is called while // the test env is already running. env.locker.Lock() @@ -2619,6 +2597,44 @@ func (env *testWorkflowEnvironmentImpl) deleteNexusAsyncOperationCompletionHandl delete(env.nexusAsyncOpHandle, uniqueOpID) } +func (env *testWorkflowEnvironmentImpl) scheduleNexusAsyncOperationCompletion( + handle *testNexusOperationHandle, +) { + completionHandle := env.getNexusAsyncOperationCompletionHandle( + handle.params.client.Service(), + handle.params.operation, + handle.operationID, + ) + if completionHandle == nil { + return + } + env.deleteNexusAsyncOperationCompletionHandle( + handle.params.client.Service(), + handle.params.operation, + handle.operationID, + ) + var nexusErr error + if completionHandle.err != nil { + nexusErr = env.failureConverter.FailureToError(nexusOperationFailure( + handle.params, + handle.operationID, + &failurepb.Failure{ + Message: completionHandle.err.Error(), + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + }, + }, + }, + )) + } + env.registerDelayedCallback(func() { + env.postCallback(func() { + handle.completedCallback(completionHandle.result, nexusErr) + }, true) + }, completionHandle.delay) +} + func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result *commonpb.Payload, err error) { env.postCallback(func() { handle, ok := env.getNexusOperationHandle(seq) @@ -2632,17 +2648,6 @@ func (env *testWorkflowEnvironmentImpl) resolveNexusOperation(seq int64, result } else { handle.completedCallback(result, nil) } - if env.onNexusOperationCompletedListener != nil { - env.onNexusOperationCompletedListener( - handle.params.client.Service(), - handle.params.operation, - newEncodedValue( - &commonpb.Payloads{Payloads: []*commonpb.Payload{result}}, - env.GetDataConverter(), - ), - err, - ) - } }, true) } @@ -3114,6 +3119,17 @@ func (h *testNexusOperationHandle) completedCallback(result *commonpb.Payload, e h.done = true h.env.deleteNexusOperationHandle(h.seq) h.onCompleted(result, err) + if h.env.onNexusOperationCompletedListener != nil { + h.env.onNexusOperationCompletedListener( + h.params.client.Service(), + h.params.operation, + newEncodedValue( + &commonpb.Payloads{Payloads: []*commonpb.Payload{result}}, + h.env.GetDataConverter(), + ), + err, + ) + } } // startedCallback is a callback registered to handle operation start. diff --git a/test/nexus_test.go b/test/nexus_test.go index b9d8f3f08..97c51b1d1 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -1323,6 +1323,29 @@ func TestWorkflowTestSuite_MockNexusOperation(t *testing.T) { require.Equal(t, "fake result", res) }) + t.Run("mock operation reference existing service", func(t *testing.T) { + suite := testsuite.WorkflowTestSuite{} + env := suite.NewTestWorkflowEnvironment() + env.RegisterNexusService(service) + env.OnNexusOperation( + serviceName, + nexus.NewOperationReference[string, string](dummyOpName), + "Temporal", + mock.Anything, + ).Return( + &nexus.HandlerStartOperationResultSync[string]{ + Value: "fake result", + }, + nil, + ) + env.ExecuteWorkflow(wf, "Temporal") + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + var res string + require.NoError(t, env.GetWorkflowResult(&res)) + require.Equal(t, "fake result", res) + }) + t.Run("mock error operation", func(t *testing.T) { suite := testsuite.WorkflowTestSuite{} env := suite.NewTestWorkflowEnvironment()