diff --git a/api.go b/api.go index faa7705d4..6d2497d10 100644 --- a/api.go +++ b/api.go @@ -1032,6 +1032,15 @@ type Sender interface { Send(*SearchResult) } +// SenderFunc is an adapter to allow the use of ordinary functions as Sender. +// If f is a function with the appropriate signature, SenderFunc(f) is a Sender +// that calls f. +type SenderFunc func(result *SearchResult) + +func (f SenderFunc) Send(result *SearchResult) { + f(result) +} + // Streamer adds the method StreamSearch to the Searcher interface. type Streamer interface { Searcher diff --git a/cmd/zoekt-webserver/grpc/server/sampling.go b/cmd/zoekt-webserver/grpc/server/sampling.go new file mode 100644 index 000000000..525dc4ebe --- /dev/null +++ b/cmd/zoekt-webserver/grpc/server/sampling.go @@ -0,0 +1,61 @@ +package server + +import ( + "math" + + "github.com/sourcegraph/zoekt" +) + +// newSamplingSender is a zoekt.Sender that samples stats events to avoid +// sending many empty stats events over the wire. +func newSamplingSender(next zoekt.Sender) *samplingSender { + return &samplingSender{next: next} +} + +type samplingSender struct { + next zoekt.Sender + agg zoekt.SearchResult + aggCount int +} + +func (s *samplingSender) Send(event *zoekt.SearchResult) { + // We don't want to send events over the wire if they don't contain file + // matches. Hence, in case we didn't find any results, we aggregate the stats + // and send them out in regular intervals. + if len(event.Files) == 0 { + s.aggCount++ + + s.agg.Stats.Add(event.Stats) + s.agg.Progress = event.Progress + + if s.aggCount%100 == 0 && !s.agg.Stats.Zero() { + s.next.Send(&s.agg) + s.agg = zoekt.SearchResult{} + } + + return + } + + // If we have aggregate stats, we merge them with the new event before sending + // it. We drop agg.Progress, because we assume that event.Progress reflects the + // latest status. + if !s.agg.Stats.Zero() { + event.Stats.Add(s.agg.Stats) + s.agg = zoekt.SearchResult{} + } + + s.next.Send(event) +} + +// Flush sends any aggregated stats that we haven't sent yet +func (s *samplingSender) Flush() { + if !s.agg.Stats.Zero() { + s.next.Send(&zoekt.SearchResult{ + Stats: s.agg.Stats, + Progress: zoekt.Progress{ + Priority: math.Inf(-1), + MaxPendingPriority: math.Inf(-1), + }, + }) + } +} diff --git a/cmd/zoekt-webserver/grpc/server/sampling_test.go b/cmd/zoekt-webserver/grpc/server/sampling_test.go new file mode 100644 index 000000000..3d06c4be6 --- /dev/null +++ b/cmd/zoekt-webserver/grpc/server/sampling_test.go @@ -0,0 +1,72 @@ +package server + +import ( + "testing" + + "github.com/sourcegraph/zoekt" +) + +func TestSamplingStream(t *testing.T) { + nonZeroStats := zoekt.Stats{ + ContentBytesLoaded: 10, + } + filesEvent := &zoekt.SearchResult{ + Files: make([]zoekt.FileMatch, 10), + Stats: nonZeroStats, + } + fileEvents := func(n int) []*zoekt.SearchResult { + res := make([]*zoekt.SearchResult, n) + for i := 0; i < n; i++ { + res[i] = filesEvent + } + return res + } + statsEvent := &zoekt.SearchResult{ + Stats: nonZeroStats, + } + statsEvents := func(n int) []*zoekt.SearchResult { + res := make([]*zoekt.SearchResult, n) + for i := 0; i < n; i++ { + res[i] = statsEvent + } + return res + } + cases := []struct { + events []*zoekt.SearchResult + beforeFlushCount int + afterFlushCount int + }{ + // These test cases assume that the sampler only forwards + // every 100 stats-only event. In case the sampling logic + // changes, these tests are not valuable. 
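+		// beforeFlushCount is the expected number of events the wrapped
+		// sender has received once all events are sent; afterFlushCount is
+		// the expected total after Flush, which emits at most one extra
+		// aggregated stats event.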
+ {nil, 0, 0}, + {fileEvents(1), 1, 1}, + {fileEvents(2), 2, 2}, + {fileEvents(200), 200, 200}, + {append(fileEvents(1), statsEvents(1)...), 1, 2}, + {append(fileEvents(1), statsEvents(2)...), 1, 2}, + {append(fileEvents(1), statsEvents(99)...), 1, 2}, + {append(fileEvents(1), statsEvents(100)...), 2, 2}, + {statsEvents(500), 5, 5}, + {statsEvents(501), 5, 6}, + } + + for _, tc := range cases { + count := 0 + ss := newSamplingSender(zoekt.SenderFunc(func(*zoekt.SearchResult) { + count += 1 + })) + + for _, event := range tc.events { + ss.Send(event) + } + if count != tc.beforeFlushCount { + t.Fatalf("expected %d events, got %d", tc.beforeFlushCount, count) + } + ss.Flush() + + if count != tc.afterFlushCount { + t.Fatalf("expected %d events, got %d", tc.afterFlushCount, count) + } + } +} diff --git a/cmd/zoekt-webserver/grpc/server/server.go b/cmd/zoekt-webserver/grpc/server/server.go index 6a1392710..812b8a2f9 100644 --- a/cmd/zoekt-webserver/grpc/server/server.go +++ b/cmd/zoekt-webserver/grpc/server/server.go @@ -11,7 +11,6 @@ import ( "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/stream" ) func NewServer(s zoekt.Streamer) *Server { @@ -48,7 +47,7 @@ func (s *Server) StreamSearch(req *proto.StreamSearchRequest, ss proto.Webserver } sender := gRPCChunkSender(ss) - sampler := stream.NewSamplingSender(sender) + sampler := newSamplingSender(sender) err = s.streamer.StreamSearch(ss.Context(), q, zoekt.SearchOptionsFromProto(request.GetOpts()), sampler) if err == nil { @@ -125,5 +124,5 @@ func gRPCChunkSender(ss proto.WebserverService_StreamSearchServer) zoekt.Sender _ = chunk.SendAll(sendFunc, result.GetFiles()...) } - return stream.SenderFunc(f) + return zoekt.SenderFunc(f) } diff --git a/cmd/zoekt-webserver/main.go b/cmd/zoekt-webserver/main.go index c54afdb42..ead52cad6 100644 --- a/cmd/zoekt-webserver/main.go +++ b/cmd/zoekt-webserver/main.go @@ -57,7 +57,6 @@ import ( "github.com/sourcegraph/zoekt/internal/tracer" "github.com/sourcegraph/zoekt/query" "github.com/sourcegraph/zoekt/shards" - "github.com/sourcegraph/zoekt/stream" "github.com/sourcegraph/zoekt/trace" "github.com/sourcegraph/zoekt/web" @@ -554,7 +553,7 @@ func (s *loggedSearcher) StreamSearch( var stats zoekt.Stats metricSearchRequestsTotal.Inc() - err := s.Streamer.StreamSearch(ctx, q, opts, stream.SenderFunc(func(event *zoekt.SearchResult) { + err := s.Streamer.StreamSearch(ctx, q, opts, zoekt.SenderFunc(func(event *zoekt.SearchResult) { stats.Add(event.Stats) sender.Send(event) })) diff --git a/query/query.go b/query/query.go index d306bce06..092c1c560 100644 --- a/query/query.go +++ b/query/query.go @@ -15,8 +15,6 @@ package query import ( - "bytes" - "encoding/gob" "encoding/json" "fmt" "log" @@ -25,7 +23,6 @@ import ( "sort" "strconv" "strings" - "sync" "github.com/RoaringBitmap/roaring" "github.com/grafana/regexp" @@ -39,17 +36,6 @@ type Q interface { String() string } -// RPCUnwrap processes q to remove RPC specific elements from q. This is -// needed because gob isn't flexible enough for us. This should be called by -// RPC servers at the client/server boundary so that q works with the rest of -// zoekt. -func RPCUnwrap(q Q) Q { - if cache, ok := q.(*GobCache); ok { - return cache.Q - } - return q -} - // RawConfig filters repositories based on their encoded RawConfig map. 
type RawConfig uint64 @@ -462,56 +448,6 @@ func (q *Regexp) setCase(k string) { } } -// GobCache exists so we only pay the cost of marshalling a query once when we -// aggregate it out over all the replicas. -// -// Our query and eval layer do not support GobCache. Instead, at the gob -// boundaries (RPC and Streaming) we check if the Q is a GobCache and unwrap -// it. -// -// "I wish we could get rid of this code soon enough" - tomas -type GobCache struct { - Q - - once sync.Once - data []byte - err error -} - -// GobEncode implements gob.Encoder. -func (q *GobCache) GobEncode() ([]byte, error) { - q.once.Do(func() { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - q.err = enc.Encode(&gobWrapper{ - WrappedQ: q.Q, - }) - q.data = buf.Bytes() - }) - return q.data, q.err -} - -// GobDecode implements gob.Decoder. -func (q *GobCache) GobDecode(data []byte) error { - dec := gob.NewDecoder(bytes.NewBuffer(data)) - var w gobWrapper - err := dec.Decode(&w) - if err != nil { - return err - } - q.Q = w.WrappedQ - return nil -} - -// gobWrapper is needed so the gob decoder works. -type gobWrapper struct { - WrappedQ Q -} - -func (q *GobCache) String() string { - return fmt.Sprintf("GobCache(%s)", q.Q) -} - // Or is matched when any of its children is matched. type Or struct { Children []Q diff --git a/query/query_proto.go b/query/query_proto.go index f02a6fcd9..83116d812 100644 --- a/query/query_proto.go +++ b/query/query_proto.go @@ -50,7 +50,6 @@ func QToProto(q Q) *proto.Q { return &proto.Q{Query: &proto.Q_Boost{Boost: v.ToProto()}} default: // The following nodes do not have a proto representation: - // - GobCache: only needed for Gob encoding // - caseQ: only used internally, not by the RPC layer panic(fmt.Sprintf("unknown query node %T", v)) } diff --git a/rpc/internal/srv/srv.go b/rpc/internal/srv/srv.go deleted file mode 100644 index f3391792c..000000000 --- a/rpc/internal/srv/srv.go +++ /dev/null @@ -1,71 +0,0 @@ -package srv - -import ( - "context" - "time" - - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/query" -) - -// defaultTimeout is the maximum amount of time a search request should -// take. This is the same default used by Sourcegraph. -const defaultTimeout = 20 * time.Second - -type SearchArgs struct { - Q query.Q - Opts *zoekt.SearchOptions -} - -type SearchReply struct { - Result *zoekt.SearchResult -} - -type ListArgs struct { - Q query.Q - Opts *zoekt.ListOptions -} - -type ListReply struct { - List *zoekt.RepoList -} - -type Searcher struct { - Searcher zoekt.Searcher -} - -func (s *Searcher) Search(ctx context.Context, args *SearchArgs, reply *SearchReply) error { - // Set a timeout if the user hasn't specified one. 
- if args.Opts != nil && args.Opts.MaxWallTime == 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, defaultTimeout) - defer cancel() - } - - if args.Q != nil { - args.Q = query.RPCUnwrap(args.Q) - } - - r, err := s.Searcher.Search(ctx, args.Q, args.Opts) - if err != nil { - return err - } - reply.Result = r - return nil -} - -func (s *Searcher) List(ctx context.Context, args *ListArgs, reply *ListReply) error { - ctx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() - - if args.Q != nil { - args.Q = query.RPCUnwrap(args.Q) - } - - r, err := s.Searcher.List(ctx, args.Q, args.Opts) - if err != nil { - return err - } - reply.List = r - return nil -} diff --git a/rpc/rpc.go b/rpc/rpc.go deleted file mode 100644 index d38c27a42..000000000 --- a/rpc/rpc.go +++ /dev/null @@ -1,210 +0,0 @@ -// Package rpc provides a zoekt.Searcher over RPC. -package rpc - -import ( - "context" - "encoding/gob" - "fmt" - "net/http" - "reflect" - "strings" - "sync" - "time" - - "github.com/keegancsmith/rpc" - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/rpc/internal/srv" -) - -// DefaultRPCPath is the rpc path used by zoekt-webserver -const DefaultRPCPath = "/rpc" - -// Server returns an http.Handler for searcher which is the server side of the -// RPC calls. -func Server(searcher zoekt.Searcher) http.Handler { - RegisterGob() - server := rpc.NewServer() - if err := server.Register(&srv.Searcher{Searcher: searcher}); err != nil { - // this should never fail, so we panic. - panic("unexpected error registering rpc server: " + err.Error()) - } - return server -} - -// Client connects to a Searcher HTTP RPC server at address (host:port) using -// DefaultRPCPath path. -func Client(address string) zoekt.Searcher { - return ClientAtPath(address, DefaultRPCPath) -} - -// ClientAtPath connects to a Searcher HTTP RPC server at address and path -// (http://host:port/path). -func ClientAtPath(address, path string) zoekt.Searcher { - RegisterGob() - return &client{addr: address, path: path} -} - -type client struct { - addr, path string - - mu sync.Mutex // protects client and gen - cl *rpc.Client - gen int // incremented each time we dial -} - -func (c *client) Search(ctx context.Context, q query.Q, opts *zoekt.SearchOptions) (*zoekt.SearchResult, error) { - var reply srv.SearchReply - err := c.call(ctx, "Searcher.Search", &srv.SearchArgs{Q: q, Opts: opts}, &reply) - return reply.Result, err -} - -func (c *client) List(ctx context.Context, q query.Q, opts *zoekt.ListOptions) (*zoekt.RepoList, error) { - var reply srv.ListReply - err := c.call(ctx, "Searcher.List", &srv.ListArgs{Q: q, Opts: opts}, &reply) - return reply.List, err -} - -func (c *client) call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error { - // We try twice. If we fail to dial or fail to call the function we try - // again after 100ms. Unrolled to make logic clear - cl, gen, err := c.getRPCClient(ctx, 0) - if err == nil { - err = cl.Call(ctx, serviceMethod, args, reply) - if err != rpc.ErrShutdown { - return err - } - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(100 * time.Millisecond): - } - - cl, _, err = c.getRPCClient(ctx, gen) - if err != nil { - return err - } - return cl.Call(ctx, serviceMethod, args, reply) -} - -// getRPCClient gets the rpc client. If gen matches the current generation, we -// redial and increment the generation. 
This is used to prevent concurrent -// redialing on network failure. -func (c *client) getRPCClient(ctx context.Context, gen int) (*rpc.Client, int, error) { - // coarse lock so we only dial once - c.mu.Lock() - defer c.mu.Unlock() - if gen != c.gen { - return c.cl, c.gen, nil - } - var timeout time.Duration - if deadline, ok := ctx.Deadline(); ok { - timeout = time.Until(deadline) - } - cl, err := rpc.DialHTTPPathTimeout("tcp", c.addr, c.path, timeout) - if err != nil { - return nil, c.gen, err - } - c.cl = cl - c.gen++ - return c.cl, c.gen, nil -} - -func (c *client) Close() { - c.mu.Lock() - defer c.mu.Unlock() - if c.cl != nil { - c.cl.Close() - } -} - -func (c *client) String() string { - return fmt.Sprintf("rpcSearcher(%s/%s)", c.addr, c.path) -} - -var once sync.Once - -// RegisterGob registers various query types with gob. It can be called more than -// once, because calls to gob.Register are protected by a sync.Once. -func RegisterGob() { - once.Do(func() { - gobRegister(&query.And{}) - gobRegister(&query.BranchRepos{}) - gobRegister(&query.BranchesRepos{}) - gobRegister(&query.Branch{}) - gobRegister(&query.Const{}) - gobRegister(&query.FileNameSet{}) - gobRegister(&query.GobCache{}) - gobRegister(&query.Language{}) - gobRegister(&query.Not{}) - gobRegister(&query.Or{}) - gobRegister(&query.Regexp{}) - gobRegister(&query.RepoRegexp{}) - gobRegister(&query.RepoSet{}) - gobRegister(&query.RepoIDs{}) - gobRegister(&query.Repo{}) - gobRegister(&query.Substring{}) - gobRegister(&query.Symbol{}) - gobRegister(&query.Type{}) - gobRegister(query.RawConfig(41)) - }) -} - -// gobRegister exists to keep backwards compatibility around renames of the go -// module. This is to avoid breaking the wire protocol due to refactors. In -// particular in August 2022 we renamed the go module from -// github.com/google/zoekt to github.com/sourcegraph/zoekt which breaks the -// wire protocol. So this function will replace those names so we keep using -// google/zoekt. -func gobRegister(value any) { - name := gobRegister_name(value) - - name = strings.Replace(name, "github.com/sourcegraph/", "github.com/google/", 1) - - gob.RegisterName(name, value) -} - -// gobRegister_name is copy-pasta from the stdlib gob.Register, returning the -// name it picks for gob.RegisterName. -func gobRegister_name(value any) string { - // Default to printed representation for unnamed types - rt := reflect.TypeOf(value) - name := rt.String() - - // But for named types (or pointers to them), qualify with import path (but see inner comment). - // Dereference one pointer looking for a named type. - star := "" - if rt.Name() == "" { - if pt := rt; pt.Kind() == reflect.Pointer { - star = "*" - // NOTE: The following line should be rt = pt.Elem() to implement - // what the comment above claims, but fixing it would break compatibility - // with existing gobs. - // - // Given package p imported as "full/p" with these definitions: - // package p - // type T1 struct { ... } - // this table shows the intended and actual strings used by gob to - // name the types: - // - // Type Correct string Actual string - // - // T1 full/p.T1 full/p.T1 - // *T1 *full/p.T1 *p.T1 - // - // The missing full path cannot be fixed without breaking existing gob decoders. - rt = pt - } - } - if rt.Name() != "" { - if rt.PkgPath() == "" { - name = star + rt.Name() - } else { - name = star + rt.PkgPath() + "." 
+ rt.Name() - } - } - - return name -} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go deleted file mode 100644 index 2265321e7..000000000 --- a/rpc/rpc_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package rpc_test - -import ( - "context" - "net/http/httptest" - "net/url" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/internal/mockSearcher" - "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/rpc" -) - -func TestClientServer(t *testing.T) { - mock := &mockSearcher.MockSearcher{ - WantSearch: query.NewAnd(mustParse("hello world|universe"), query.NewSingleBranchesRepos("HEAD", 1, 2)), - SearchResult: &zoekt.SearchResult{ - Files: []zoekt.FileMatch{ - {FileName: "bin.go"}, - }, - }, - - WantList: &query.Const{Value: true}, - RepoList: &zoekt.RepoList{ - Repos: []*zoekt.RepoListEntry{ - { - Repository: zoekt.Repository{ - ID: 2, - Name: "foo/bar", - }, - }, - }, - }, - } - - ts := httptest.NewServer(rpc.Server(mock)) - defer ts.Close() - - u, err := url.Parse(ts.URL) - if err != nil { - t.Fatal(err) - } - client := rpc.Client(u.Host) - defer client.Close() - - var cached query.Q = &query.GobCache{ - Q: mock.WantSearch, - } - - r, err := client.Search(context.Background(), cached, &zoekt.SearchOptions{}) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(r, mock.SearchResult) { - t.Fatalf("got %+v, want %+v", r, mock.SearchResult) - } - - l, err := client.List(context.Background(), mock.WantList, nil) - if err != nil { - t.Fatal(err) - } - if d := cmp.Diff(mock.RepoList, l, cmpopts.IgnoreUnexported(zoekt.Repository{})); d != "" { - t.Fatalf("unexpected RepoList (-want, +got):\n%s", d) - } - - // Test closing a client we never dial. - noopClient := rpc.Client(u.Host) - noopClient.Close() -} - -func mustParse(s string) query.Q { - q, err := query.Parse(s) - if err != nil { - panic(err) - } - return q -} diff --git a/shards/aggregate.go b/shards/aggregate.go index e6faf5de1..a89d99593 100644 --- a/shards/aggregate.go +++ b/shards/aggregate.go @@ -9,7 +9,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/stream" ) var metricFinalAggregateSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ @@ -138,7 +137,7 @@ func newFlushCollectSender(opts *zoekt.SearchOptions, sender zoekt.Sender) (zoek stopCollectingAndFlush(zoekt.FlushReasonFinalFlush) } - return stream.SenderFunc(func(event *zoekt.SearchResult) { + return zoekt.SenderFunc(func(event *zoekt.SearchResult) { mu.Lock() if collectSender != nil { collectSender.Send(event) @@ -152,7 +151,7 @@ func newFlushCollectSender(opts *zoekt.SearchOptions, sender zoekt.Sender) (zoek // limitSender wraps a sender and calls cancel once the truncator has finished // truncating. 
func limitSender(cancel context.CancelFunc, sender zoekt.Sender, truncator zoekt.DisplayTruncator) zoekt.Sender { - return stream.SenderFunc(func(result *zoekt.SearchResult) { + return zoekt.SenderFunc(func(result *zoekt.SearchResult) { var hasMore bool result.Files, hasMore = truncator(result.Files) if !hasMore { @@ -163,7 +162,7 @@ func limitSender(cancel context.CancelFunc, sender zoekt.Sender, truncator zoekt } func copyFileSender(sender zoekt.Sender) zoekt.Sender { - return stream.SenderFunc(func(result *zoekt.SearchResult) { + return zoekt.SenderFunc(func(result *zoekt.SearchResult) { copyFiles(result) sender.Send(result) }) diff --git a/shards/eval.go b/shards/eval.go index 8b7a2da24..caa0e6440 100644 --- a/shards/eval.go +++ b/shards/eval.go @@ -5,7 +5,6 @@ import ( "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/stream" "github.com/sourcegraph/zoekt/trace" ) @@ -59,7 +58,7 @@ func (s *typeRepoSearcher) StreamSearch(ctx context.Context, q query.Q, opts *zo return err } - return s.Streamer.StreamSearch(ctx, q, opts, stream.SenderFunc(func(event *zoekt.SearchResult) { + return s.Streamer.StreamSearch(ctx, q, opts, zoekt.SenderFunc(func(event *zoekt.SearchResult) { stats.Add(event.Stats) sender.Send(event) })) diff --git a/shards/shards_test.go b/shards/shards_test.go index e800dd511..5c4ddc735 100644 --- a/shards/shards_test.go +++ b/shards/shards_test.go @@ -37,7 +37,6 @@ import ( "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/stream" ) type crashSearcher struct{} @@ -258,7 +257,7 @@ func TestShardedSearcher_DocumentRanking(t *testing.T) { } err := ss.StreamSearch(context.Background(), &query.Substring{Pattern: "foo"}, opts, - stream.SenderFunc(func(event *zoekt.SearchResult) { + zoekt.SenderFunc(func(event *zoekt.SearchResult) { results = append(results, event) })) if err != nil { @@ -1129,7 +1128,7 @@ func testShardedStreamSearch(t *testing.T, q query.Q, ib *zoekt.IndexBuilder, us ss.replace(map[string]zoekt.Searcher{"r1": searcher}) var files []zoekt.FileMatch - sender := stream.SenderFunc(func(result *zoekt.SearchResult) { + sender := zoekt.SenderFunc(func(result *zoekt.SearchResult) { files = append(files, result.Files...) }) diff --git a/stream/client.go b/stream/client.go deleted file mode 100644 index 4bc4f6dec..000000000 --- a/stream/client.go +++ /dev/null @@ -1,126 +0,0 @@ -package stream - -import ( - "bytes" - "context" - "encoding/gob" - "fmt" - "net/http" - - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/query" -) - -// Doer implements the minimal surface of *http.Client and http.RoundTripper needed -// by Client. -type Doer interface { - Do(*http.Request) (*http.Response, error) -} - -// NewClient returns a client which implements StreamSearch. If httpClient is -// nil, http.DefaultClient is used. -func NewClient(address string, httpClient Doer) *Client { - registerGob() - if httpClient == nil { - httpClient = http.DefaultClient - } - return &Client{ - address: address, - httpClient: httpClient, - } -} - -// Client is an HTTP client for StreamSearch. Do not create directly, call -// NewClient. -type Client struct { - // HTTP address of zoekt-webserver. Will query against address + "/stream". - address string - - // httpClient when set is used instead of http.DefaultClient - httpClient Doer -} - -// SenderFunc is an adapter to allow the use of ordinary functions as Sender. 
-// If f is a function with the appropriate signature, SenderFunc(f) is a Sender -// that calls f. -type SenderFunc func(result *zoekt.SearchResult) - -func (f SenderFunc) Send(result *zoekt.SearchResult) { - f(result) -} - -// StreamSearch returns search results as stream by calling streamer.Send(event) -// for each event returned by the server. -// -// Error events returned by the server are returned as error. Context errors are -// recreated and returned on a best-efforts basis. -func (c *Client) StreamSearch(ctx context.Context, q query.Q, opts *zoekt.SearchOptions, streamer zoekt.Sender) error { - // Encode query and opts. - buf := new(bytes.Buffer) - args := &searchArgs{ - q, opts, - } - enc := gob.NewEncoder(buf) - err := enc.Encode(args) - if err != nil { - return fmt.Errorf("error during encoding: %w", err) - } - - // Send request. - req, err := http.NewRequestWithContext(ctx, "POST", c.address+DefaultSSEPath, buf) - if err != nil { - return err - } - req.Header.Set("Accept", "application/x-gob-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Transfer-Encoding", "chunked") - - resp, err := c.httpClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - dec := gob.NewDecoder(resp.Body) - for { - reply := &searchReply{} - err := dec.Decode(reply) - if err != nil { - return fmt.Errorf("error during decoding: %w", err) - } - switch reply.Event { - case eventMatches: - if res, ok := reply.Data.(*zoekt.SearchResult); ok { - streamer.Send(res) - } else { - return fmt.Errorf("event of type %s could not be converted to *zoekt.SearchResult", eventMatches.string()) - } - case eventError: - if errString, ok := reply.Data.(string); ok { - return fmt.Errorf("error received from zoekt: %s", errString) - } else { - return fmt.Errorf("data for event of type %s could not be converted to string", eventError.string()) - } - case eventDone: - return nil - default: - return fmt.Errorf("unknown event type") - } - } -} - -// WithSearcher returns Streamer composed of s and the streaming client. All -// non-streaming calls will go via s, while streaming calls will go via the -// streaming client. -func (c *Client) WithSearcher(s zoekt.Searcher) zoekt.Streamer { - return &streamer{ - Searcher: s, - Client: c, - } -} - -type streamer struct { - zoekt.Searcher - *Client -} diff --git a/stream/stream.go b/stream/stream.go deleted file mode 100644 index a05093a78..000000000 --- a/stream/stream.go +++ /dev/null @@ -1,209 +0,0 @@ -// Package stream provides a client and a server to consume search results as -// stream. -package stream - -import ( - "encoding/gob" - "errors" - "math" - "net/http" - "sync" - - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/rpc" -) - -// DefaultSSEPath is the path used by zoekt-webserver. -const DefaultSSEPath = "/stream" - -type eventType int - -const ( - eventMatches eventType = iota - eventError - eventDone -) - -func (e eventType) string() string { - return []string{"eventMatches", "eventError", "eventDone"}[e] -} - -// Server returns an http.Handler which is the server side of StreamSearch. 
-func Server(searcher zoekt.Streamer) http.Handler { - registerGob() - return &handler{Searcher: searcher} -} - -type searchArgs struct { - Q query.Q - Opts *zoekt.SearchOptions -} - -type searchReply struct { - Event eventType - Data interface{} -} - -type handler struct { - Searcher zoekt.Streamer -} - -func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Decode payload. - args := new(searchArgs) - err := gob.NewDecoder(r.Body).Decode(args) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - args.Q = query.RPCUnwrap(args.Q) - - eventWriter, err := newEventStreamWriter(w) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Always send a done event in the end. - defer func() { - err = eventWriter.event(eventDone, nil) - if err != nil { - _ = eventWriter.event(eventError, err) - } - }() - - send := func(zsr *zoekt.SearchResult) { - err := eventWriter.event(eventMatches, zsr) - if err != nil { - _ = eventWriter.event(eventError, err) - return - } - } - - sampler := NewSamplingSender(SenderFunc(send)) - - err = h.Searcher.StreamSearch(ctx, args.Q, args.Opts, sampler) - - if err == nil { - sampler.Flush() - } - - if err != nil { - _ = eventWriter.event(eventError, err) - return - } -} - -type eventStreamWriter struct { - enc *gob.Encoder - flush func() -} - -func newEventStreamWriter(w http.ResponseWriter) (*eventStreamWriter, error) { - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("http flushing not supported") - } - - w.Header().Set("Content-Type", "application/x-gob-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - - // This informs nginx to not buffer. With buffering search responses will - // be delayed until buffers get full, leading to worst case latency of the - // full time a search takes to complete. - w.Header().Set("X-Accel-Buffering", "no") - - return &eventStreamWriter{ - enc: gob.NewEncoder(w), - flush: flusher.Flush, - }, nil -} - -func (e *eventStreamWriter) event(event eventType, data interface{}) error { - // Because gob does not support serializing errors, we send error.Error() and - // recreate the error on the client-side. - if event == eventError { - if err, isError := data.(error); isError { - data = err.Error() - } - } - err := e.enc.Encode(searchReply{Event: event, Data: data}) - if err != nil { - return err - } - e.flush() - return nil -} - -var once sync.Once - -func registerGob() { - once.Do(func() { - gob.Register(&zoekt.SearchResult{}) - }) - rpc.RegisterGob() -} - -// NewSamplingSender is a zoekt.Sender that samples stats events -// to avoid sending many empty stats events over the wire. -func NewSamplingSender(next zoekt.Sender) *samplingSender { - return &samplingSender{ - next: next, - agg: zoekt.SearchResult{}, - aggCount: 0, - } -} - -type samplingSender struct { - next zoekt.Sender - agg zoekt.SearchResult - aggCount int -} - -func (s *samplingSender) Send(event *zoekt.SearchResult) { - // We don't want to send events over the wire if they don't contain file - // matches. Hence, in case we didn't find any results, we aggregate the stats - // and send them out in regular intervals. 
- if len(event.Files) == 0 { - s.aggCount++ - - s.agg.Stats.Add(event.Stats) - s.agg.Progress = event.Progress - - if s.aggCount%100 == 0 && !s.agg.Stats.Zero() { - s.next.Send(&s.agg) - s.agg = zoekt.SearchResult{} - } - - return - } - - // If we have aggregate stats, we merge them with the new event before sending - // it. We drop agg.Progress, because we assume that event.Progress reflects the - // latest status. - if !s.agg.Stats.Zero() { - event.Stats.Add(s.agg.Stats) - s.agg = zoekt.SearchResult{} - } - - s.next.Send(event) -} - -// Flush sends any aggregated stats that we haven't sent yet -func (s *samplingSender) Flush() { - if !s.agg.Stats.Zero() { - s.next.Send(&zoekt.SearchResult{ - Stats: s.agg.Stats, - Progress: zoekt.Progress{ - Priority: math.Inf(-1), - MaxPendingPriority: math.Inf(-1), - }, - }) - } -} diff --git a/stream/stream_test.go b/stream/stream_test.go deleted file mode 100644 index ec59c342d..000000000 --- a/stream/stream_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package stream - -import ( - "bytes" - "context" - "encoding/gob" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/sourcegraph/zoekt" - "github.com/sourcegraph/zoekt/internal/mockSearcher" - "github.com/sourcegraph/zoekt/query" -) - -func TestStreamSearch(t *testing.T) { - q := query.NewAnd(mustParse("hello world|universe"), query.NewRepoSet("foo/bar", "baz/bam")) - searcher := &mockSearcher.MockSearcher{ - WantSearch: q, - SearchResult: &zoekt.SearchResult{ - Files: []zoekt.FileMatch{ - {FileName: "bin.go"}, - }, - }, - } - - h := &handler{Searcher: adapter{searcher}} - - s := httptest.NewServer(h) - defer s.Close() - - cl := NewClient(s.URL, nil) - - c := make(chan *zoekt.SearchResult, 100) - - err := cl.StreamSearch(context.Background(), q, nil, streamerChan(c)) - if err != nil { - t.Fatal(err) - } - close(c) - - for res := range c { - if res.Files == nil { - continue - } - if res.Files[0].FileName != "bin.go" { - t.Errorf("got %s, wanted %s", res.Files[0].FileName, "bin.go") - } - } -} - -func TestStreamSearchJustStats(t *testing.T) { - wantStats := zoekt.Stats{ - Crashes: 1, - } - q := query.NewAnd(mustParse("hello world|universe"), query.NewRepoSet("foo/bar", "baz/bam")) - searcher := &mockSearcher.MockSearcher{ - WantSearch: q, - SearchResult: &zoekt.SearchResult{ - Files: []zoekt.FileMatch{}, - Stats: wantStats, - }, - } - - h := &handler{Searcher: adapter{searcher}} - - s := httptest.NewServer(h) - defer s.Close() - - cl := NewClient(s.URL, nil) - - c := make(chan *zoekt.SearchResult, 100) - - err := cl.StreamSearch(context.Background(), q, nil, streamerChan(c)) - if err != nil { - t.Fatal(err) - } - close(c) - - count := 0 - for res := range c { - count += 1 - if count > 1 { - t.Fatal("expected exactly 1 result, got at least 2") - } - if d := cmp.Diff(wantStats, res.Stats); d != "" { - t.Fatalf("zoekt.Stats mismatch (-want +got): %s\n", d) - } - } - if count != 1 { - t.Fatal("expected exactly 1 result, got 0") - } -} - -func TestEventStreamWriter(t *testing.T) { - registerGob() - network := new(bytes.Buffer) - enc := gob.NewEncoder(network) - dec := gob.NewDecoder(network) - - esw := eventStreamWriter{ - enc: enc, - flush: func() {}, - } - - tests := []struct { - event eventType - data interface{} - }{ - { - eventDone, - nil, - }, - { - eventMatches, - &zoekt.SearchResult{ - Files: []zoekt.FileMatch{ - {FileName: "bin.go"}, - }, - }, - }, - { - eventError, - "test error", - }, - } - - for _, tt := range tests { - t.Run(tt.event.string(), 
func(t *testing.T) { - err := esw.event(tt.event, tt.data) - if err != nil { - t.Fatal(err) - } - reply := new(searchReply) - err = dec.Decode(reply) - if err != nil { - t.Fatal(err) - } - if reply.Event != tt.event { - t.Fatalf("got %s, want %s", reply.Event.string(), tt.event.string()) - } - if d := cmp.Diff(tt.data, reply.Data); d != "" { - t.Fatalf("mismatch for event type %s (-want +got):\n%s", tt.event.string(), d) - } - }) - } -} - -func TestServerError(t *testing.T) { - serverError := fmt.Errorf("zoekt server error") - h := func(w http.ResponseWriter, r *http.Request) { - esw, err := newEventStreamWriter(w) - if err != nil { - t.Fatal(err) - } - err = esw.event(eventError, serverError) - if err != nil { - t.Fatal(err) - } - } - s := httptest.NewServer(http.HandlerFunc(h)) - cl := NewClient(s.URL, nil) - err := cl.StreamSearch(context.Background(), nil, nil, streamerChan(make(chan *zoekt.SearchResult))) - if err == nil { - t.Fatalf("got nil, want %s", serverError) - } -} - -func mustParse(s string) query.Q { - q, err := query.Parse(s) - if err != nil { - panic(err) - } - return q -} - -type streamerChan chan<- *zoekt.SearchResult - -func (c streamerChan) Send(result *zoekt.SearchResult) { - c <- result -} - -type adapter struct { - zoekt.Searcher -} - -func (a adapter) StreamSearch(ctx context.Context, q query.Q, opts *zoekt.SearchOptions, sender zoekt.Sender) (err error) { - sr, err := a.Searcher.Search(ctx, q, opts) - if err != nil { - return err - } - sender.Send(sr) - return nil -} - -func TestSamplingStream(t *testing.T) { - nonZeroStats := zoekt.Stats{ - ContentBytesLoaded: 10, - } - filesEvent := &zoekt.SearchResult{ - Files: make([]zoekt.FileMatch, 10), - Stats: nonZeroStats, - } - fileEvents := func(n int) []*zoekt.SearchResult { - res := make([]*zoekt.SearchResult, n) - for i := 0; i < n; i++ { - res[i] = filesEvent - } - return res - } - statsEvent := &zoekt.SearchResult{ - Stats: nonZeroStats, - } - statsEvents := func(n int) []*zoekt.SearchResult { - res := make([]*zoekt.SearchResult, n) - for i := 0; i < n; i++ { - res[i] = statsEvent - } - return res - } - cases := []struct { - events []*zoekt.SearchResult - beforeFlushCount int - afterFlushCount int - }{ - // These test cases assume that the sampler only forwards - // every 100 stats-only event. In case the sampling logic - // changes, these tests are not valuable. - {nil, 0, 0}, - {fileEvents(1), 1, 1}, - {fileEvents(2), 2, 2}, - {fileEvents(200), 200, 200}, - {append(fileEvents(1), statsEvents(1)...), 1, 2}, - {append(fileEvents(1), statsEvents(2)...), 1, 2}, - {append(fileEvents(1), statsEvents(99)...), 1, 2}, - {append(fileEvents(1), statsEvents(100)...), 2, 2}, - {statsEvents(500), 5, 5}, - {statsEvents(501), 5, 6}, - } - - for _, tc := range cases { - count := 0 - ss := NewSamplingSender(SenderFunc(func(*zoekt.SearchResult) { - count += 1 - })) - - for _, event := range tc.events { - ss.Send(event) - } - if count != tc.beforeFlushCount { - t.Fatalf("expected %d events, got %d", tc.beforeFlushCount, count) - } - ss.Flush() - - if count != tc.afterFlushCount { - t.Fatalf("expected %d events, got %d", tc.afterFlushCount, count) - } - } -} diff --git a/web/e2e_test.go b/web/e2e_test.go index e0be04a71..5cbb63d13 100644 --- a/web/e2e_test.go +++ b/web/e2e_test.go @@ -33,8 +33,6 @@ import ( "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/rpc" - "github.com/sourcegraph/zoekt/stream" ) // TODO(hanwen): cut & paste from ../ . 
Should create internal test @@ -963,61 +961,6 @@ func TestHealthz(t *testing.T) { } } -func TestRPC(t *testing.T) { - b, err := zoekt.NewIndexBuilder(&zoekt.Repository{ - Name: "name", - URL: "repo-url", - CommitURLTemplate: "{{.Version}}", - FileURLTemplate: "file-url", - LineFragmentTemplate: "#line", - Branches: []zoekt.RepositoryBranch{{Name: "master", Version: "1234"}}, - }) - if err != nil { - t.Fatalf("NewIndexBuilder: %v", err) - } - if err := b.Add(zoekt.Document{ - Name: "f2", - Content: []byte("to carry water in the no later bla"), - // --------------0123456789012345678901234567890123 - // --------------0 1 2 3 - Branches: []string{"master"}, - }); err != nil { - t.Fatalf("Add: %v", err) - } - - s := searcherForTest(t, b) - srv := Server{ - Searcher: s, - RPC: true, - Top: Top, - } - - mux, err := NewMux(&srv) - if err != nil { - t.Fatalf("NewMux: %v", err) - } - - ts := httptest.NewServer(mux) - defer ts.Close() - - endpoint := ts.Listener.Addr().String() - - client := stream.NewClient("http://"+endpoint, nil).WithSearcher(rpc.Client(endpoint)) - - ctx := context.Background() - q := &query.Substring{Pattern: "water"} - opts := &zoekt.SearchOptions{ChunkMatches: true} - opts.SetDefaults() - results, err := client.Search(ctx, q, opts) - if err != nil { - t.Fatal(err) - } - - assertResults(t, results.Files, "f2: to carry water in the no later bla") - - // TODO grpc, List, StreamSearch -} - func assertResults(t *testing.T, files []zoekt.FileMatch, want string) { t.Helper() diff --git a/web/server.go b/web/server.go index 6476ca69b..75b631fd2 100644 --- a/web/server.go +++ b/web/server.go @@ -34,8 +34,6 @@ import ( "github.com/sourcegraph/zoekt" zjson "github.com/sourcegraph/zoekt/json" "github.com/sourcegraph/zoekt/query" - "github.com/sourcegraph/zoekt/rpc" - "github.com/sourcegraph/zoekt/stream" ) var Funcmap = template.FuncMap{ @@ -176,9 +174,7 @@ func NewMux(s *Server) (*http.ServeMux, error) { mux.HandleFunc("/print", s.servePrint) } if s.RPC { - mux.Handle(rpc.DefaultRPCPath, rpc.Server(traceAwareSearcher{s.Searcher})) // /rpc mux.Handle("/api/", http.StripPrefix("/api", zjson.JSONServer(traceAwareSearcher{s.Searcher}))) - mux.Handle(stream.DefaultSSEPath, stream.Server(traceAwareSearcher{s.Searcher})) // /stream } mux.HandleFunc("/healthz", s.serveHealthz)
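
The zoekt.SenderFunc adapter added in api.go is what replaces the deleted stream.SenderFunc throughout this change: any ordinary closure with the right signature now satisfies zoekt.Sender. The snippet below is a minimal illustrative sketch of a caller, not part of this patch; the package name "example" and the function collectFiles are hypothetical, while zoekt.Streamer, zoekt.SenderFunc, zoekt.SearchResult, and query.Q are the identifiers used in the diff above.

package example // hypothetical package, for illustration only

import (
	"context"

	"github.com/sourcegraph/zoekt"
	"github.com/sourcegraph/zoekt/query"
)

// collectFiles streams a search and gathers file matches by passing a plain
// closure, wrapped in zoekt.SenderFunc, as the zoekt.Sender. Callers no
// longer need the removed stream package for this.
func collectFiles(ctx context.Context, s zoekt.Streamer, q query.Q, opts *zoekt.SearchOptions) ([]zoekt.FileMatch, error) {
	var files []zoekt.FileMatch
	err := s.StreamSearch(ctx, q, opts, zoekt.SenderFunc(func(r *zoekt.SearchResult) {
		files = append(files, r.Files...)
	}))
	return files, err
}

This mirrors the call sites updated in cmd/zoekt-webserver/main.go, shards/eval.go, and shards/shards_test.go above.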