Skip to content

Commit

Permalink
Fix Go routine leaks in streaming calls (#15293)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored Feb 20, 2024
1 parent f1a95e1 commit c1a176c
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
25 changes: 25 additions & 0 deletions go/vt/vttablet/grpctabletconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T

// BeginStreamExecute starts a transaction and runs an Execute.
func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()

conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -650,6 +654,9 @@ func (conn *gRPCQueryClient) StreamHealth(ctx context.Context, callback func(*qu

// VStream starts a VReplication stream.
func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -695,6 +702,9 @@ func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.

// VStreamRows streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamRowsClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -737,6 +747,9 @@ func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdat

// VStreamTables streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(*binlogdatapb.VStreamTablesResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamTablesClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -777,6 +790,9 @@ func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogd

// VStreamResults streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamResultsClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -856,6 +872,9 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu

// ReserveBeginStreamExecute implements the queryservice interface
func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -967,6 +986,9 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb

// ReserveStreamExecute implements the queryservice interface
func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -1060,6 +1082,9 @@ func (conn *gRPCQueryClient) Release(ctx context.Context, target *querypb.Target

// GetSchema implements the queryservice interface
func (conn *gRPCQueryClient) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down
116 changes: 116 additions & 0 deletions go/vt/vttablet/grpctabletconn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@ limitations under the License.
package grpctabletconn

import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"

"vitess.io/vitess/go/sqltypes"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
queryservicepb "vitess.io/vitess/go/vt/proto/queryservice"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vttablet/grpcqueryservice"
"vitess.io/vitess/go/vt/vttablet/tabletconntest"
Expand Down Expand Up @@ -113,3 +121,111 @@ func TestGRPCTabletAuthConn(t *testing.T) {
},
}, service, f)
}

// mockQueryClient is a mock query client that returns an error from Streaming calls,
// but only after storing the context that was passed to the RPC.
type mockQueryClient struct {
lastCallCtx context.Context
queryservicepb.QueryClient
}

func (m *mockQueryClient) StreamExecute(ctx context.Context, in *querypb.StreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_StreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) BeginStreamExecute(ctx context.Context, in *querypb.BeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_BeginStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) ReserveStreamExecute(ctx context.Context, in *querypb.ReserveStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) ReserveBeginStreamExecute(ctx context.Context, in *querypb.ReserveBeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveBeginStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStream(ctx context.Context, in *binlogdatapb.VStreamRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamRows(ctx context.Context, in *binlogdatapb.VStreamRowsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamRowsClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamTables(ctx context.Context, in *binlogdatapb.VStreamTablesRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamTablesClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamResults(ctx context.Context, in *binlogdatapb.VStreamResultsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamResultsClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) GetSchema(ctx context.Context, in *querypb.GetSchemaRequest, opts ...grpc.CallOption) (queryservicepb.Query_GetSchemaClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

var _ queryservicepb.QueryClient = (*mockQueryClient)(nil)

// TestGoRoutineLeakPrevention tests that after all the RPCs that stream queries, we end up closing the context that was passed to it, to prevent go routines from being leaked.
func TestGoRoutineLeakPrevention(t *testing.T) {
mqc := &mockQueryClient{}
qc := &gRPCQueryClient{
mu: sync.RWMutex{},
cc: &grpc.ClientConn{},
c: mqc,
}
_ = qc.StreamExecute(context.Background(), nil, "", nil, 0, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.BeginStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, "", nil, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStream(context.Background(), &binlogdatapb.VStreamRequest{}, func(events []*binlogdatapb.VEvent) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamRows(context.Background(), &binlogdatapb.VStreamRowsRequest{}, func(response *binlogdatapb.VStreamRowsResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamResults(context.Background(), nil, "", func(response *binlogdatapb.VStreamResultsResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamTables(context.Background(), &binlogdatapb.VStreamTablesRequest{}, func(response *binlogdatapb.VStreamTablesResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.GetSchema(context.Background(), nil, querypb.SchemaTableType_TABLES, nil, func(schemaRes *querypb.GetSchemaResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())
}

0 comments on commit c1a176c

Please sign in to comment.