diff --git a/core/mr/mapreduce.go b/core/mr/mapreduce.go index 4d71e89a614e..517c94b893fa 100644 --- a/core/mr/mapreduce.go +++ b/core/mr/mapreduce.go @@ -16,7 +16,10 @@ const ( minWorkers = 1 ) -var ErrCancelWithNil = errors.New("mapreduce cancelled with nil") +var ( + ErrCancelWithNil = errors.New("mapreduce cancelled with nil") + ErrReduceNoOutput = errors.New("reduce not writing value") +) type ( GenerateFunc func(source chan<- interface{}) @@ -93,7 +96,14 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R collector := make(chan interface{}, options.workers) done := syncx.NewDoneChan() writer := newGuardedWriter(output, done.Done()) + var closeOnce sync.Once var retErr errorx.AtomicError + finish := func() { + closeOnce.Do(func() { + done.Close() + close(output) + }) + } cancel := once(func(err error) { if err != nil { retErr.Set(err) @@ -102,14 +112,15 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R } drain(source) - done.Close() - close(output) + finish() }) go func() { defer func() { if r := recover(); r != nil { cancel(fmt.Errorf("%v", r)) + } else { + finish() } }() reducer(collector, writer, cancel) @@ -122,7 +133,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R } else if ok { return value, nil } else { - return nil, nil + return nil, ErrReduceNoOutput } } diff --git a/example/mapreduce/deadlock/main.go b/example/mapreduce/deadlock/main.go new file mode 100644 index 000000000000..60d3fd13f2a1 --- /dev/null +++ b/example/mapreduce/deadlock/main.go @@ -0,0 +1,40 @@ +package main + +import ( + "log" + "strconv" + + "github.com/tal-tech/go-zero/core/mr" +) + +type User struct { + Uid int + Name string +} + +func main() { + uids := []int{111, 222, 333} + res, err := mr.MapReduce(func(source chan<- interface{}) { + for _, uid := range uids { + source <- uid + } + }, func(item interface{}, writer mr.Writer, cancel func(error)) { + uid := item.(int) + user := &User{ + Uid: uid, + Name: strconv.Itoa(uid), + } + writer.Write(user) + }, func(pipe <-chan interface{}, writer mr.Writer, cancel func(error)) { + var users []*User + for p := range pipe { + users = append(users, p.(*User)) + } + // missing writer.Write(...), should not panic + }) + if err != nil { + log.Print(err) + return + } + log.Print(len(res.([]*User))) +}