From 819fdc1b51e4f62655864f6012ca5a3b03328cc1 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 20 Feb 2024 14:32:40 +0530 Subject: [PATCH 1/3] feat: create a cancellable context for StreamExecute calls to ensure there are no go-routine leaks Signed-off-by: Manan Gupta --- go/vt/vttablet/grpctabletconn/conn.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index 775118aee73..c4b3e83b119 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -473,6 +473,10 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T // BeginStreamExecute starts a transaction and runs an Execute. func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -856,6 +860,9 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu // ReserveBeginStreamExecute implements the queryservice interface func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -967,6 +974,9 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb // ReserveStreamExecute implements the queryservice interface func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { From dbd0ed0a2e3167d2e6063aa7a69007869a35f202 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 20 Feb 2024 15:36:47 +0530 Subject: [PATCH 2/3] feat: add unit test for the changes Signed-off-by: Manan Gupta --- go/vt/vttablet/grpctabletconn/conn_test.go | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index fb182bfe2e4..0f8db0051d1 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -17,13 +17,20 @@ limitations under the License. package grpctabletconn import ( + "context" + "fmt" "io" "net" "os" + "sync" "testing" + "github.com/stretchr/testify/require" "google.golang.org/grpc" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" "vitess.io/vitess/go/vt/vttablet/tabletconntest" @@ -113,3 +120,61 @@ func TestGRPCTabletAuthConn(t *testing.T) { }, }, service, f) } + +// mockQueryClient is a mock query client that returns an error from Streaming calls, +// but only after storing the context that was passed to the RPC. +type mockQueryClient struct { + lastCallCtx context.Context + queryservicepb.QueryClient +} + +func (m *mockQueryClient) StreamExecute(ctx context.Context, in *querypb.StreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_StreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) BeginStreamExecute(ctx context.Context, in *querypb.BeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_BeginStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) ReserveStreamExecute(ctx context.Context, in *querypb.ReserveStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) ReserveBeginStreamExecute(ctx context.Context, in *querypb.ReserveBeginStreamExecuteRequest, opts ...grpc.CallOption) (queryservicepb.Query_ReserveBeginStreamExecuteClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +var _ queryservicepb.QueryClient = (*mockQueryClient)(nil) + +// TestGoRoutineLeakPrevention tests that after all the RPCs that stream queries, we end up closing the context that was passed to it, to prevent go routines from being leaked. +func TestGoRoutineLeakPrevention(t *testing.T) { + mqc := &mockQueryClient{} + qc := &gRPCQueryClient{ + mu: sync.RWMutex{}, + cc: &grpc.ClientConn{}, + c: mqc, + } + _ = qc.StreamExecute(context.Background(), nil, "", nil, 0, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.BeginStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, "", nil, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) +} From 316de39481e13e46470cdf17de5ebe6ba206c942 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 20 Feb 2024 16:17:26 +0530 Subject: [PATCH 3/3] feat: fix more RPCs Signed-off-by: Manan Gupta --- go/vt/vttablet/grpctabletconn/conn.go | 15 +++++++ go/vt/vttablet/grpctabletconn/conn_test.go | 51 ++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index c4b3e83b119..8bb8a466b21 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -654,6 +654,9 @@ func (conn *gRPCQueryClient) StreamHealth(ctx context.Context, callback func(*qu // VStream starts a VReplication stream. func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -699,6 +702,9 @@ func (conn *gRPCQueryClient) VStream(ctx context.Context, request *binlogdatapb. // VStreamRows streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamRowsClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -741,6 +747,9 @@ func (conn *gRPCQueryClient) VStreamRows(ctx context.Context, request *binlogdat // VStreamTables streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(*binlogdatapb.VStreamTablesResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamTablesClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -781,6 +790,9 @@ func (conn *gRPCQueryClient) VStreamTables(ctx context.Context, request *binlogd // VStreamResults streams rows of a query from the specified starting point. func (conn *gRPCQueryClient) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() stream, err := func() (queryservicepb.Query_VStreamResultsClient, error) { conn.mu.RLock() defer conn.mu.RUnlock() @@ -1070,6 +1082,9 @@ func (conn *gRPCQueryClient) Release(ctx context.Context, target *querypb.Target // GetSchema implements the queryservice interface func (conn *gRPCQueryClient) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) error { + // Please see comments in StreamExecute to see how this works. + ctx, cancel := context.WithCancel(ctx) + defer cancel() conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index 0f8db0051d1..70e30e337bc 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc" "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" "vitess.io/vitess/go/vt/servenv" @@ -148,6 +149,31 @@ func (m *mockQueryClient) ReserveBeginStreamExecute(ctx context.Context, in *que return nil, fmt.Errorf("A general error") } +func (m *mockQueryClient) VStream(ctx context.Context, in *binlogdatapb.VStreamRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamRows(ctx context.Context, in *binlogdatapb.VStreamRowsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamRowsClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamTables(ctx context.Context, in *binlogdatapb.VStreamTablesRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamTablesClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) VStreamResults(ctx context.Context, in *binlogdatapb.VStreamResultsRequest, opts ...grpc.CallOption) (queryservicepb.Query_VStreamResultsClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + +func (m *mockQueryClient) GetSchema(ctx context.Context, in *querypb.GetSchemaRequest, opts ...grpc.CallOption) (queryservicepb.Query_GetSchemaClient, error) { + m.lastCallCtx = ctx + return nil, fmt.Errorf("A general error") +} + var _ queryservicepb.QueryClient = (*mockQueryClient)(nil) // TestGoRoutineLeakPrevention tests that after all the RPCs that stream queries, we end up closing the context that was passed to it, to prevent go routines from being leaked. @@ -177,4 +203,29 @@ func TestGoRoutineLeakPrevention(t *testing.T) { return nil }) require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStream(context.Background(), &binlogdatapb.VStreamRequest{}, func(events []*binlogdatapb.VEvent) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamRows(context.Background(), &binlogdatapb.VStreamRowsRequest{}, func(response *binlogdatapb.VStreamRowsResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamResults(context.Background(), nil, "", func(response *binlogdatapb.VStreamResultsResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.VStreamTables(context.Background(), &binlogdatapb.VStreamTablesRequest{}, func(response *binlogdatapb.VStreamTablesResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) + + _ = qc.GetSchema(context.Background(), nil, querypb.SchemaTableType_TABLES, nil, func(schemaRes *querypb.GetSchemaResponse) error { + return nil + }) + require.Error(t, mqc.lastCallCtx.Err()) }