Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Go routine leaks in streaming calls #15293

Merged
merged 3 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions go/vt/vttablet/grpctabletconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T

// BeginStreamExecute starts a transaction and runs an Execute.
func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()

conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -650,6 +654,9 @@ func (conn *gRPCQueryClient) StreamHealth(ctx context.Context, callback func(*qu

// VStream starts a VReplication stream.
func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -695,6 +702,9 @@ func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.

// VStreamRows streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamRowsClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -737,6 +747,9 @@ func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdat

// VStreamTables streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(*binlogdatapb.VStreamTablesResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamTablesClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -777,6 +790,9 @@ func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogd

// VStreamResults streams rows of a query from the specified starting point.
func (conn *gRPCQueryClient) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := func() (queryservicepb.Query_VStreamResultsClient, error) {
conn.mu.RLock()
defer conn.mu.RUnlock()
Expand Down Expand Up @@ -856,6 +872,9 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu

// ReserveBeginStreamExecute implements the queryservice interface
func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -967,6 +986,9 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb

// ReserveStreamExecute implements the queryservice interface
func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down Expand Up @@ -1060,6 +1082,9 @@ func (conn *gRPCQueryClient) Release(ctx context.Context, target *querypb.Target

// GetSchema implements the queryservice interface
func (conn *gRPCQueryClient) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) error {
// Please see comments in StreamExecute to see how this works.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
conn.mu.RLock()
defer conn.mu.RUnlock()
if conn.cc == nil {
Expand Down
116 changes: 116 additions & 0 deletions go/vt/vttablet/grpctabletconn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@ limitations under the License.
package grpctabletconn

import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"

"vitess.io/vitess/go/sqltypes"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
queryservicepb "vitess.io/vitess/go/vt/proto/queryservice"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vttablet/grpcqueryservice"
"vitess.io/vitess/go/vt/vttablet/tabletconntest"
Expand Down Expand Up @@ -113,3 +121,111 @@ func TestGRPCTabletAuthConn(t *testing.T) {
},
}, service, f)
}

// mockQueryClient is a mock query client that returns an error from Streaming calls,
// but only after storing the context that was passed to the RPC.
type mockQueryClient struct {
lastCallCtx context.Context
queryservicepb.QueryClient
}

func (m *mockQueryClient) StreamExecute(ctx context.Context, in *querypb.StreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_StreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) BeginStreamExecute(ctx context.Context, in *querypb.BeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_BeginStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) ReserveStreamExecute(ctx context.Context, in *querypb.ReserveStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) ReserveBeginStreamExecute(ctx context.Context, in *querypb.ReserveBeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveBeginStreamExecuteClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStream(ctx context.Context, in *binlogdatapb.VStreamRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamRows(ctx context.Context, in *binlogdatapb.VStreamRowsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamRowsClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamTables(ctx context.Context, in *binlogdatapb.VStreamTablesRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamTablesClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) VStreamResults(ctx context.Context, in *binlogdatapb.VStreamResultsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamResultsClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

func (m *mockQueryClient) GetSchema(ctx context.Context, in *querypb.GetSchemaRequest, opts ...grpc.CallOption) (queryservicepb.Query_GetSchemaClient, error) {
m.lastCallCtx = ctx
return nil, fmt.Errorf("A general error")
}

var _ queryservicepb.QueryClient = (*mockQueryClient)(nil)

// TestGoRoutineLeakPrevention tests that after all the RPCs that stream queries, we end up closing the context that was passed to it, to prevent go routines from being leaked.
func TestGoRoutineLeakPrevention(t *testing.T) {
mqc := &mockQueryClient{}
qc := &gRPCQueryClient{
mu: sync.RWMutex{},
cc: &grpc.ClientConn{},
c: mqc,
}
_ = qc.StreamExecute(context.Background(), nil, "", nil, 0, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.BeginStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, "", nil, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStream(context.Background(), &binlogdatapb.VStreamRequest{}, func(events []*binlogdatapb.VEvent) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamRows(context.Background(), &binlogdatapb.VStreamRowsRequest{}, func(response *binlogdatapb.VStreamRowsResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamResults(context.Background(), nil, "", func(response *binlogdatapb.VStreamResultsResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.VStreamTables(context.Background(), &binlogdatapb.VStreamTablesRequest{}, func(response *binlogdatapb.VStreamTablesResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())

_ = qc.GetSchema(context.Background(), nil, querypb.SchemaTableType_TABLES, nil, func(schemaRes *querypb.GetSchemaResponse) error {
return nil
})
require.Error(t, mqc.lastCallCtx.Err())
}
Loading