Skip to content

Commit

Permalink
shards: respect scheduler and use smarter synchronization for List
Browse files Browse the repository at this point in the history
Previously List would never call proc.Yield, which broke the
co-operative scheduler in the case of slow List calls. Additionally, we
used a naive concurrency scheme (a large buffered channel), which shows up in
the profiler as 30% of CPU spent on chan_ related operations under List.

This commit follows how Search used to respect proc.Yield. See sched.go
in 90ed7bf. We did not copy Search's implementation
because it is more complicated than we need: it
supports streaming, while List is still batch only.

We needed to use errgroup to ensure we drained all channels in the case
of an error. Previously we did not need to do this since the channels
had a buffer size of len(shards), which guaranteed nothing would ever
block. Now channels are never larger than the number of workers (<=
GOMAXPROCS).

Test Plan: go test covers the no error cases. In the case of errors we
manually tested by running zoekt-webserver and adding a random context
cancellation. We observed the error being reported and no List
goroutines running.

Co-authored-by: William Bezuidenhout <[email protected]>
  • Loading branch information
keegancsmith and burmudar committed Mar 26, 2024
1 parent 06de7ad commit 753239b
Showing 1 changed file with 63 additions and 29 deletions.
92 changes: 63 additions & 29 deletions shards/shards.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"sync"
"time"

"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"

"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -872,20 +873,19 @@ type shardListResult struct {
err error
}

func listOneShard(ctx context.Context, s zoekt.Searcher, q query.Q, opts *zoekt.ListOptions, sink chan shardListResult) {
func listOneShard(ctx context.Context, s zoekt.Searcher, q query.Q, opts *zoekt.ListOptions) (result *zoekt.RepoList, _ error) {
metricListShardRunning.Inc()
defer func() {
metricListShardRunning.Dec()
// If we panic, we log the panic and set Crashes (but do not return an
// error).
if r := recover(); r != nil {
log.Printf("crashed shard: %s: %s, %s", s.String(), r, debug.Stack())
sink <- shardListResult{
&zoekt.RepoList{Crashes: 1}, nil,
}
result = &zoekt.RepoList{Crashes: 1}
}
}()

ms, err := s.List(ctx, q, opts)
sink <- shardListResult{ms, err}
return s.List(ctx, q, opts)
}

func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.ListOptions) (rl *zoekt.RepoList, err error) {
Expand Down Expand Up @@ -944,36 +944,65 @@ func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.List
return &agg, nil
}

shardCount := len(shards)
all := make(chan shardListResult, shardCount)
tr.LazyPrintf("shardCount: %d", len(shards))
// We use an errgroup so that when an error is encountered we can stop
// feeder and report the first error seen.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
g, ctx := errgroup.WithContext(ctx)

feeder := make(chan zoekt.Searcher, len(shards))
for _, s := range shards {
feeder <- s
}
close(feeder)
// Bound work by number of CPUs.
workers := min(runtime.GOMAXPROCS(0), len(shards))

for i := 0; i < runtime.GOMAXPROCS(0); i++ {
go func() {
var (
feeder = make(chan zoekt.Searcher, workers)
all = make(chan *zoekt.RepoList, workers)
)

// Send shards to feeder until context is canceled.
g.Go(func() error {
defer close(feeder)
for _, s := range shards {
// If context is canceled we stop consuming from shards and cancel the
// errgroup.
if err := proc.Yield(ctx); err != nil {
return err
}
feeder <- s
}
return nil
})

// Start up workers goroutines to consume feeder, do listing of a shard and
// send results down all. If an error is encountered we cancel the errgroup.
for range workers {
g.Go(func() error {
for s := range feeder {
listOneShard(ctx, s, r, opts, all)
result, err := listOneShard(ctx, s, r, opts)
if err != nil {
return err
}
all <- result
}
}()
return nil
})
}

uniq := map[string]*zoekt.RepoListEntry{}

for range shards {
r := <-all
if r.err != nil {
return nil, r.err
}
// Once all goroutines in errgroup are done, we know nothing more will be
// sent to all so close it. We rely on this sync point such that workersErr
// will be written to before we are finished reading from all.
var workersErr error
go func() {
workersErr = g.Wait()
close(all)
}()

agg.Crashes += r.rl.Crashes
agg.Stats.Add(&r.rl.Stats)
// Aggregate results from all.
uniq := map[string]*zoekt.RepoListEntry{}
for rl := range all {
agg.Crashes += rl.Crashes
agg.Stats.Add(&rl.Stats)

for _, r := range r.rl.Repos {
for _, r := range rl.Repos {
prev, ok := uniq[r.Repository.Name]
if !ok {
cp := *r // We need to copy because we mutate r.Stats when merging duplicates
Expand All @@ -983,14 +1012,19 @@ func (ss *shardedSearcher) List(ctx context.Context, r query.Q, opts *zoekt.List
}
}

for id, r := range r.rl.ReposMap {
for id, r := range rl.ReposMap {
_, ok := agg.ReposMap[id]
if !ok {
agg.ReposMap[id] = r
}
}
}

// workersErr will now be set since all is closed.
if workersErr != nil {
return nil, workersErr
}

agg.Repos = make([]*zoekt.RepoListEntry, 0, len(uniq))
for _, r := range uniq {
agg.Repos = append(agg.Repos, r)
Expand Down

0 comments on commit 753239b

Please sign in to comment.