diff --git a/shards/shards.go b/shards/shards.go index 74a84e3e..c69d649f 100644 --- a/shards/shards.go +++ b/shards/shards.go @@ -27,6 +27,7 @@ import ( "sync" "time" + "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" "github.com/prometheus/client_golang/prometheus" @@ -873,20 +874,19 @@ type shardListResult struct { err error } -func listOneShard(ctx context.Context, s zoekt.Searcher, q query.Q, opts *zoekt.ListOptions, sink chan shardListResult) { +func listOneShard(ctx context.Context, s zoekt.Searcher, q query.Q, opts *zoekt.ListOptions) (result *zoekt.RepoList, _ error) { metricListShardRunning.Inc() defer func() { metricListShardRunning.Dec() + // If we panic, we log the panic and set Crashes (but do not return an + // error). if r := recover(); r != nil { log.Printf("crashed shard: %s: %s, %s", s.String(), r, debug.Stack()) - sink <- shardListResult{ - &zoekt.RepoList{Crashes: 1}, nil, - } + result = &zoekt.RepoList{Crashes: 1} } }() - ms, err := s.List(ctx, q, opts) - sink <- shardListResult{ms, err} + return s.List(ctx, q, opts) } func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.ListOptions) (rl *zoekt.RepoList, err error) { @@ -945,36 +945,65 @@ func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.List return &agg, nil } - shardCount := len(shards) - all := make(chan shardListResult, shardCount) - tr.LazyPrintf("shardCount: %d", len(shards)) + // We use an errgroup so that when an error is encountered we can stop + // feeder and report the first error seen. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + g, ctx := errgroup.WithContext(ctx) - feeder := make(chan zoekt.Searcher, len(shards)) - for _, s := range shards { - feeder <- s - } - close(feeder) + // Bound work by number of CPUs. + workers := min(runtime.GOMAXPROCS(0), len(shards)) - for i := 0; i < runtime.GOMAXPROCS(0); i++ { - go func() { + var ( + feeder = make(chan zoekt.Searcher, workers) + all = make(chan *zoekt.RepoList, workers) + ) + + // Send shards to feeder until context is canceled. + g.Go(func() error { + defer close(feeder) + for _, s := range shards { + // If context is canceled we stop consuming from shards and cancel the + // errgroup. + if err := proc.Yield(ctx); err != nil { + return err + } + feeder <- s + } + return nil + }) + + // Start up workers goroutines to consume feeder, do listing of a shard and + // send results down all. If an error is encountered we cancel the errgroup. + for range workers { + g.Go(func() error { for s := range feeder { - listOneShard(ctx, s, r, opts, all) + result, err := listOneShard(ctx, s, r, opts) + if err != nil { + return err + } + all <- result } - }() + return nil + }) } - uniq := map[string]*zoekt.RepoListEntry{} - - for range shards { - r := <-all - if r.err != nil { - return nil, r.err - } + // Once all goroutines in errgroup is done, we know nothing more will be + // sent to all so close it. We rely on this sync point such that workersErr + // will be written to before we are finished reading from all. + var workersErr error + go func() { + workersErr = g.Wait() + close(all) + }() - agg.Crashes += r.rl.Crashes - agg.Stats.Add(&r.rl.Stats) + // Aggregate results from all. + uniq := map[string]*zoekt.RepoListEntry{} + for rl := range all { + agg.Crashes += rl.Crashes + agg.Stats.Add(&rl.Stats) - for _, r := range r.rl.Repos { + for _, r := range rl.Repos { prev, ok := uniq[r.Repository.Name] if !ok { cp := *r // We need to copy because we mutate r.Stats when merging duplicates @@ -984,7 +1013,7 @@ func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.List } } - for id, r := range r.rl.ReposMap { + for id, r := range rl.ReposMap { _, ok := agg.ReposMap[id] if !ok { agg.ReposMap[id] = r @@ -992,6 +1021,11 @@ func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.List } } + // workersErr will now be set since all is closed. + if workersErr != nil { + return nil, workersErr + } + agg.Repos = make([]*zoekt.RepoListEntry, 0, len(uniq)) for _, r := range uniq { agg.Repos = append(agg.Repos, r)