From b8b14483744be73d138f7d58fdcb3a2829d5b901 Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Tue, 4 Feb 2020 17:20:45 -0800 Subject: [PATCH] Fix YARPC context propagation (#116) --- service/frontend/workflowHandler.go | 78 ++++++++++++++---------- service/frontend/workflowHandler_test.go | 7 ++- 2 files changed, 49 insertions(+), 36 deletions(-) diff --git a/service/frontend/workflowHandler.go b/service/frontend/workflowHandler.go index 948017ffb2a..0393ce9d63f 100644 --- a/service/frontend/workflowHandler.go +++ b/service/frontend/workflowHandler.go @@ -397,7 +397,7 @@ func (wh *WorkflowHandler) PollForActivityTask( DomainUUID: common.StringPtr(domainID), PollerID: common.StringPtr(pollerID), PollRequest: pollRequest, - }) + }, versionHeaders(ctx)...) return err } @@ -487,7 +487,7 @@ func (wh *WorkflowHandler) PollForDecisionTask( DomainUUID: common.StringPtr(domainID), PollerID: common.StringPtr(pollerID), PollRequest: pollRequest, - }) + }, versionHeaders(ctx)...) return err } @@ -547,7 +547,7 @@ func (wh *WorkflowHandler) cancelOutstandingPoll(ctx context.Context, err error, TaskListType: common.Int32Ptr(taskListType), TaskList: taskList, PollerID: common.StringPtr(pollerID), - }) + }, versionHeaders(ctx)...) // We can not do much if this call fails. Just log the error and move on if err != nil { wh.GetLogger().Warn("Failed to cancel outstanding poller.", @@ -629,7 +629,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -638,7 +638,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat( resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), HeartbeatRequest: heartbeatRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -730,7 +730,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -745,7 +745,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID( resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), HeartbeatRequest: req, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -823,7 +823,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -831,7 +831,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted( err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: completeRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -926,7 +926,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -940,7 +940,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID( err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: req, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1017,7 +1017,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failedRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1114,7 +1114,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: req, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1191,7 +1191,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1199,7 +1199,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled( err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CancelRequest: cancelRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1293,7 +1293,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1307,7 +1307,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID( err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CancelRequest: req, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1362,6 +1362,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskCompleted( histResp, err := wh.GetHistoryClient().RespondDecisionTaskCompleted(ctx, &h.RespondDecisionTaskCompletedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), CompleteRequest: completeRequest}, + versionHeaders(ctx)..., ) if err != nil { return nil, wh.error(err, scope) @@ -1464,7 +1465,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskFailed( err = wh.GetHistoryClient().RespondDecisionTaskFailed(ctx, &h.RespondDecisionTaskFailedRequest{ DomainUUID: common.StringPtr(taskToken.DomainID), FailedRequest: failedRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1547,7 +1548,7 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted( CompletedRequest: completeRequest, } - err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest) + err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -1692,7 +1693,7 @@ func (wh *WorkflowHandler) StartWorkflowExecution( } wh.GetLogger().Debug("Start workflow execution request domainID", tag.WorkflowDomainID(domainID)) - resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest)) + resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest), versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) @@ -1763,7 +1764,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionRawHistory( Execution: execution, ExpectedNextEventId: nil, CurrentBranchToken: currentBranchToken, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, "", 0, err @@ -1919,7 +1920,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory( Execution: execution, ExpectedNextEventId: common.Int64Ptr(expectedNextEventID), CurrentBranchToken: currentBranchToken, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, "", 0, 0, false, err @@ -2131,7 +2132,7 @@ func (wh *WorkflowHandler) SignalWorkflowExecution( err = wh.GetHistoryClient().SignalWorkflowExecution(ctx, &h.SignalWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), SignalRequest: signalRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -2266,7 +2267,7 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution( resp, err = wh.GetHistoryClient().SignalWithStartWorkflowExecution(ctx, &h.SignalWithStartWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), SignalWithStartRequest: signalWithStartRequest, - }) + }, versionHeaders(ctx)...) return err } @@ -2317,7 +2318,7 @@ func (wh *WorkflowHandler) TerminateWorkflowExecution( err = wh.GetHistoryClient().TerminateWorkflowExecution(ctx, &h.TerminateWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), TerminateRequest: terminateRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -2364,7 +2365,7 @@ func (wh *WorkflowHandler) ResetWorkflowExecution( resp, err = wh.GetHistoryClient().ResetWorkflowExecution(ctx, &h.ResetWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), ResetRequest: resetRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -2410,7 +2411,7 @@ func (wh *WorkflowHandler) RequestCancelWorkflowExecution( err = wh.GetHistoryClient().RequestCancelWorkflowExecution(ctx, &h.RequestCancelWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), CancelRequest: cancelRequest, - }) + }, versionHeaders(ctx)...) if err != nil { return wh.error(err, scope) } @@ -2973,7 +2974,7 @@ func (wh *WorkflowHandler) ResetStickyTaskList( _, err = wh.GetHistoryClient().ResetStickyTaskList(ctx, &h.ResetStickyTaskListRequest{ DomainUUID: common.StringPtr(domainID), Execution: resetRequest.Execution, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -3041,7 +3042,7 @@ func (wh *WorkflowHandler) QueryWorkflow( DomainUUID: common.StringPtr(domainID), Request: queryRequest, } - hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req) + hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) } @@ -3085,7 +3086,7 @@ func (wh *WorkflowHandler) DescribeWorkflowExecution( response, err := wh.GetHistoryClient().DescribeWorkflowExecution(ctx, &h.DescribeWorkflowExecutionRequest{ DomainUUID: common.StringPtr(domainID), Request: request, - }) + }, versionHeaders(ctx)...) if err != nil { return nil, wh.error(err, scope) @@ -3140,7 +3141,7 @@ func (wh *WorkflowHandler) DescribeTaskList( response, err = wh.GetMatchingClient().DescribeTaskList(ctx, &m.DescribeTaskListRequest{ DomainUUID: common.StringPtr(domainID), DescRequest: request, - }) + }, versionHeaders(ctx)...) return err } @@ -3178,7 +3179,7 @@ func (wh *WorkflowHandler) ListTaskListPartitions(ctx context.Context, request * resp, err := wh.GetMatchingClient().ListTaskListPartitions(ctx, &m.ListTaskListPartitionsRequest{ Domain: request.Domain, TaskList: request.TaskList, - }) + }, versionHeaders(ctx)...) return resp, err } @@ -3651,7 +3652,7 @@ func (wh *WorkflowHandler) historyArchived(ctx context.Context, request *gen.Get DomainUUID: common.StringPtr(domainID), Execution: request.Execution, } - _, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest) + _, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest, versionHeaders(ctx)...) if err == nil { return false } @@ -3777,3 +3778,14 @@ type domainWrapper struct { func (d domainWrapper) GetDomain() string { return d.domain } + +// TODO: Remove this func after history and matching services gRPC migration is complete +// It sets version headers in YARPC format +func versionHeaders(ctx context.Context) []yarpc.CallOption { + headers := client.GetHeadersValue(ctx, common.LibraryVersionHeaderName, common.FeatureVersionHeaderName, common.ClientImplHeaderName) + return []yarpc.CallOption{ + yarpc.WithHeader(common.LibraryVersionHeaderName, headers[0]), + yarpc.WithHeader(common.FeatureVersionHeaderName, headers[1]), + yarpc.WithHeader(common.ClientImplHeaderName, headers[2]), + } +} diff --git a/service/frontend/workflowHandler_test.go b/service/frontend/workflowHandler_test.go index 15170fc2fc7..ef61032537f 100644 --- a/service/frontend/workflowHandler_test.go +++ b/service/frontend/workflowHandler_test.go @@ -834,7 +834,8 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + // TODO: remove last 3 `gomock.Any()` after YARPC migration + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID), @@ -843,7 +844,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1) + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID), @@ -852,7 +853,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() { } s.True(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain")) - s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1) + s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1) getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{ Execution: &shared.WorkflowExecution{ WorkflowId: common.StringPtr(testWorkflowID),