Skip to content

Commit

Permalink
Expand WF context locking to cover WFT responses
Browse files Browse the repository at this point in the history
Failure to hold this lock while responding to a workflow task allows
in-flight tasks to race past each other which has led to history
corruption in the case where responses are not deduplicated. Furthermore
the correctness of resetting the event level while not holding this lock
is unclear at best.
  • Loading branch information
Matt McShane committed Jul 29, 2023
1 parent 746bcf2 commit 685c6a1
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 98 deletions.
17 changes: 16 additions & 1 deletion internal/internal_public.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,30 @@ type (

// WorkflowTaskHandler represents workflow task handlers.
WorkflowTaskHandler interface {
WorkflowContextManager

// Processes the workflow task
// The response could be:
// - RespondWorkflowTaskCompletedRequest
// - RespondWorkflowTaskFailedRequest
// - RespondQueryTaskCompletedRequest
ProcessWorkflowTask(
task *workflowTask,
ctx *workflowExecutionContextImpl,
f workflowTaskHeartbeatFunc,
) (response interface{}, resetter EventLevelResetter, err error)
) (response interface{}, err error)
}

WorkflowContextManager interface {
// GetOrCreateWorkflowContext finds an existing cached context object
// for the provided task's run ID or creates a new object, adds it to
// cache, and returns it. In all non-error cases the returned context
// object is in a locked state (i.e.
// workflowExecutionContextImpl.Lock() has been called).
GetOrCreateWorkflowContext(
task *workflowservice.PollWorkflowTaskQueueResponse,
historyIterator HistoryIterator,
) (*workflowExecutionContextImpl, error)
}

// ActivityTaskHandler represents activity task handlers.
Expand Down
40 changes: 19 additions & 21 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,19 @@ func newWorkflowExecutionContext(
return workflowContext
}

// Lock acquires the lock on this context object, use Unlock(error) to release
// the lock.
func (w *workflowExecutionContextImpl) Lock() {
w.mutex.Lock()
}

// Unlock cleans up after the provided error and it's own internal view of the
// workflow error state by clearing itself and removing itself from cache as
// needed. It is an error to call this function without having called the Lock
// function first and the behavior is undefined. Regardless of the error
// handling involved, the context will be unlocked when this call returns.
func (w *workflowExecutionContextImpl) Unlock(err error) {
defer w.mutex.Unlock()
if err != nil || w.err != nil || w.isWorkflowCompleted ||
(w.wth.cache.MaxWorkflowCacheSize() <= 0 && !w.hasPendingLocalActivityWork()) {
// TODO: in case of closed, it asumes the close command always succeed. need server side change to return
Expand All @@ -496,8 +504,6 @@ func (w *workflowExecutionContextImpl) Unlock(err error) {
// exited
w.clearState()
}

w.mutex.Unlock()
}

func (w *workflowExecutionContextImpl) getEventHandler() *workflowExecutionEventHandlerImpl {
Expand Down Expand Up @@ -631,7 +637,7 @@ func (wth *workflowTaskHandlerImpl) createWorkflowContext(task *workflowservice.
return newWorkflowExecutionContext(workflowInfo, wth), nil
}

func (wth *workflowTaskHandlerImpl) getOrCreateWorkflowContext(
func (wth *workflowTaskHandlerImpl) GetOrCreateWorkflowContext(
task *workflowservice.PollWorkflowTaskQueueResponse,
historyIterator HistoryIterator,
) (workflowContext *workflowExecutionContextImpl, err error) {
Expand Down Expand Up @@ -756,10 +762,11 @@ func (w *workflowExecutionContextImpl) resetStateIfDestroyed(task *workflowservi
// ProcessWorkflowTask processes all the events of the workflow task.
func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask(
workflowTask *workflowTask,
workflowContext *workflowExecutionContextImpl,
heartbeatFunc workflowTaskHeartbeatFunc,
) (completeRequest interface{}, resetter EventLevelResetter, errRet error) {
) (completeRequest interface{}, errRet error) {
if workflowTask == nil || workflowTask.task == nil {
return nil, nil, errors.New("nil workflow task provided")
return nil, errors.New("nil workflow task provided")
}
task := workflowTask.task
if task.History == nil || len(task.History.Events) == 0 {
Expand All @@ -768,11 +775,11 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask(
}
}
if task.Query == nil && len(task.History.Events) == 0 {
return nil, nil, errors.New("nil or empty history")
return nil, errors.New("nil or empty history")
}

if task.Query != nil && len(task.Queries) != 0 {
return nil, nil, errors.New("invalid query workflow task")
return nil, errors.New("invalid query workflow task")
}

runID := task.WorkflowExecution.GetRunId()
Expand All @@ -786,18 +793,12 @@ func (wth *workflowTaskHandlerImpl) ProcessWorkflowTask(
tagPreviousStartedEventID, task.GetPreviousStartedEventId())
})

workflowContext, err := wth.getOrCreateWorkflowContext(task, workflowTask.historyIterator)
if err != nil {
return nil, nil, err
}

defer func() {
workflowContext.Unlock(errRet)
}()

var response interface{}
var (
response interface{}
err error
heartbeatTimer *time.Timer
)

var heartbeatTimer *time.Timer
defer func() {
if heartbeatTimer != nil {
heartbeatTimer.Stop()
Expand Down Expand Up @@ -882,7 +883,6 @@ processWorkflowLoop:
}
errRet = err
completeRequest = response
resetter = workflowContext.SetPreviousStartedEventID
return
}

Expand Down Expand Up @@ -1250,8 +1250,6 @@ func (w *workflowExecutionContextImpl) SetCurrentTask(task *workflowservice.Poll
}

func (w *workflowExecutionContextImpl) SetPreviousStartedEventID(eventID int64) {
w.mutex.Lock() // This call can race against the cache eviction thread - see clearState
defer w.mutex.Unlock()
w.previousStartedEventID = eventID
}

Expand Down
14 changes: 11 additions & 3 deletions internal/internal_task_handlers_interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,19 @@ type sampleWorkflowTaskHandler struct{}

func (wth sampleWorkflowTaskHandler) ProcessWorkflowTask(
workflowTask *workflowTask,
_ *workflowExecutionContextImpl,
_ workflowTaskHeartbeatFunc,
) (interface{}, EventLevelResetter, error) {
) (interface{}, error) {
return &workflowservice.RespondWorkflowTaskCompletedRequest{
TaskToken: workflowTask.task.TaskToken,
}, nil, nil
}, nil
}

func (wth sampleWorkflowTaskHandler) GetOrCreateWorkflowContext(
task *workflowservice.PollWorkflowTaskQueueResponse,
historyIterator HistoryIterator,
) (*workflowExecutionContextImpl, error) {
return nil, nil
}

func newSampleWorkflowTaskHandler() *sampleWorkflowTaskHandler {
Expand Down Expand Up @@ -115,7 +123,7 @@ func (s *PollLayerInterfacesTestSuite) TestProcessWorkflowTaskInterface() {

// Process task and respond to the service.
taskHandler := newSampleWorkflowTaskHandler()
request, _, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil)
request, err := taskHandler.ProcessWorkflowTask(&workflowTask{task: response}, nil, nil)
completionRequest := request.(*workflowservice.RespondWorkflowTaskCompletedRequest)
s.NoError(err)

Expand Down
Loading

0 comments on commit 685c6a1

Please sign in to comment.