Skip to content

Commit

Permalink
Fix YARPC context propagation (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin authored Feb 5, 2020
1 parent 05d04dd commit b8b1448
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
78 changes: 45 additions & 33 deletions service/frontend/workflowHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func (wh *WorkflowHandler) PollForActivityTask(
DomainUUID: common.StringPtr(domainID),
PollerID: common.StringPtr(pollerID),
PollRequest: pollRequest,
})
}, versionHeaders(ctx)...)
return err
}

Expand Down Expand Up @@ -487,7 +487,7 @@ func (wh *WorkflowHandler) PollForDecisionTask(
DomainUUID: common.StringPtr(domainID),
PollerID: common.StringPtr(pollerID),
PollRequest: pollRequest,
})
}, versionHeaders(ctx)...)
return err
}

Expand Down Expand Up @@ -547,7 +547,7 @@ func (wh *WorkflowHandler) cancelOutstandingPoll(ctx context.Context, err error,
TaskListType: common.Int32Ptr(taskListType),
TaskList: taskList,
PollerID: common.StringPtr(pollerID),
})
}, versionHeaders(ctx)...)
// We can not do much if this call fails. Just log the error and move on
if err != nil {
wh.GetLogger().Warn("Failed to cancel outstanding poller.",
Expand Down Expand Up @@ -629,7 +629,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand All @@ -638,7 +638,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat(
resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
HeartbeatRequest: heartbeatRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand Down Expand Up @@ -730,7 +730,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand All @@ -745,7 +745,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID(
resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
HeartbeatRequest: req,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand Down Expand Up @@ -823,15 +823,15 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
} else {
err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
CompleteRequest: completeRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -926,7 +926,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand All @@ -940,7 +940,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID(
err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
CompleteRequest: req,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1017,7 +1017,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failedRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1114,7 +1114,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: req,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1191,15 +1191,15 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
} else {
err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
CancelRequest: cancelRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1293,7 +1293,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID(
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand All @@ -1307,7 +1307,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID(
err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
CancelRequest: req,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1362,6 +1362,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskCompleted(
histResp, err := wh.GetHistoryClient().RespondDecisionTaskCompleted(ctx, &h.RespondDecisionTaskCompletedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
CompleteRequest: completeRequest},
versionHeaders(ctx)...,
)
if err != nil {
return nil, wh.error(err, scope)
Expand Down Expand Up @@ -1464,7 +1465,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskFailed(
err = wh.GetHistoryClient().RespondDecisionTaskFailed(ctx, &h.RespondDecisionTaskFailedRequest{
DomainUUID: common.StringPtr(taskToken.DomainID),
FailedRequest: failedRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1547,7 +1548,7 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted(
CompletedRequest: completeRequest,
}

err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest)
err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -1692,7 +1693,7 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
}

wh.GetLogger().Debug("Start workflow execution request domainID", tag.WorkflowDomainID(domainID))
resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest))
resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest), versionHeaders(ctx)...)

if err != nil {
return nil, wh.error(err, scope)
Expand Down Expand Up @@ -1763,7 +1764,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionRawHistory(
Execution: execution,
ExpectedNextEventId: nil,
CurrentBranchToken: currentBranchToken,
})
}, versionHeaders(ctx)...)

if err != nil {
return nil, "", 0, err
Expand Down Expand Up @@ -1919,7 +1920,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(
Execution: execution,
ExpectedNextEventId: common.Int64Ptr(expectedNextEventID),
CurrentBranchToken: currentBranchToken,
})
}, versionHeaders(ctx)...)

if err != nil {
return nil, "", 0, 0, false, err
Expand Down Expand Up @@ -2131,7 +2132,7 @@ func (wh *WorkflowHandler) SignalWorkflowExecution(
err = wh.GetHistoryClient().SignalWorkflowExecution(ctx, &h.SignalWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
SignalRequest: signalRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -2266,7 +2267,7 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution(
resp, err = wh.GetHistoryClient().SignalWithStartWorkflowExecution(ctx, &h.SignalWithStartWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
SignalWithStartRequest: signalWithStartRequest,
})
}, versionHeaders(ctx)...)
return err
}

Expand Down Expand Up @@ -2317,7 +2318,7 @@ func (wh *WorkflowHandler) TerminateWorkflowExecution(
err = wh.GetHistoryClient().TerminateWorkflowExecution(ctx, &h.TerminateWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
TerminateRequest: terminateRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -2364,7 +2365,7 @@ func (wh *WorkflowHandler) ResetWorkflowExecution(
resp, err = wh.GetHistoryClient().ResetWorkflowExecution(ctx, &h.ResetWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
ResetRequest: resetRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand Down Expand Up @@ -2410,7 +2411,7 @@ func (wh *WorkflowHandler) RequestCancelWorkflowExecution(
err = wh.GetHistoryClient().RequestCancelWorkflowExecution(ctx, &h.RequestCancelWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
CancelRequest: cancelRequest,
})
}, versionHeaders(ctx)...)
if err != nil {
return wh.error(err, scope)
}
Expand Down Expand Up @@ -2973,7 +2974,7 @@ func (wh *WorkflowHandler) ResetStickyTaskList(
_, err = wh.GetHistoryClient().ResetStickyTaskList(ctx, &h.ResetStickyTaskListRequest{
DomainUUID: common.StringPtr(domainID),
Execution: resetRequest.Execution,
})
}, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand Down Expand Up @@ -3041,7 +3042,7 @@ func (wh *WorkflowHandler) QueryWorkflow(
DomainUUID: common.StringPtr(domainID),
Request: queryRequest,
}
hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req)
hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req, versionHeaders(ctx)...)
if err != nil {
return nil, wh.error(err, scope)
}
Expand Down Expand Up @@ -3085,7 +3086,7 @@ func (wh *WorkflowHandler) DescribeWorkflowExecution(
response, err := wh.GetHistoryClient().DescribeWorkflowExecution(ctx, &h.DescribeWorkflowExecutionRequest{
DomainUUID: common.StringPtr(domainID),
Request: request,
})
}, versionHeaders(ctx)...)

if err != nil {
return nil, wh.error(err, scope)
Expand Down Expand Up @@ -3140,7 +3141,7 @@ func (wh *WorkflowHandler) DescribeTaskList(
response, err = wh.GetMatchingClient().DescribeTaskList(ctx, &m.DescribeTaskListRequest{
DomainUUID: common.StringPtr(domainID),
DescRequest: request,
})
}, versionHeaders(ctx)...)
return err
}

Expand Down Expand Up @@ -3178,7 +3179,7 @@ func (wh *WorkflowHandler) ListTaskListPartitions(ctx context.Context, request *
resp, err := wh.GetMatchingClient().ListTaskListPartitions(ctx, &m.ListTaskListPartitionsRequest{
Domain: request.Domain,
TaskList: request.TaskList,
})
}, versionHeaders(ctx)...)
return resp, err
}

Expand Down Expand Up @@ -3651,7 +3652,7 @@ func (wh *WorkflowHandler) historyArchived(ctx context.Context, request *gen.Get
DomainUUID: common.StringPtr(domainID),
Execution: request.Execution,
}
_, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest)
_, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest, versionHeaders(ctx)...)
if err == nil {
return false
}
Expand Down Expand Up @@ -3777,3 +3778,14 @@ type domainWrapper struct {
func (d domainWrapper) GetDomain() string {
return d.domain
}

// TODO: Remove this func after history and matching services gRPC migration is complete
// It sets version headers in YARPC format
func versionHeaders(ctx context.Context) []yarpc.CallOption {
headers := client.GetHeadersValue(ctx, common.LibraryVersionHeaderName, common.FeatureVersionHeaderName, common.ClientImplHeaderName)
return []yarpc.CallOption{
yarpc.WithHeader(common.LibraryVersionHeaderName, headers[0]),
yarpc.WithHeader(common.FeatureVersionHeaderName, headers[1]),
yarpc.WithHeader(common.ClientImplHeaderName, headers[2]),
}
}
7 changes: 4 additions & 3 deletions service/frontend/workflowHandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,8 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
}
s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))

s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
// TODO: remove last 3 `gomock.Any()` after YARPC migration
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
Execution: &shared.WorkflowExecution{
WorkflowId: common.StringPtr(testWorkflowID),
Expand All @@ -843,7 +844,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
}
s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))

s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1)
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1)
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
Execution: &shared.WorkflowExecution{
WorkflowId: common.StringPtr(testWorkflowID),
Expand All @@ -852,7 +853,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
}
s.True(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))

s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1)
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1)
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
Execution: &shared.WorkflowExecution{
WorkflowId: common.StringPtr(testWorkflowID),
Expand Down

0 comments on commit b8b1448

Please sign in to comment.