diff --git a/go/vt/vtgate/grpcvtgateservice/server.go b/go/vt/vtgate/grpcvtgateservice/server.go index bf00db4ea1c..248e4d21d0b 100644 --- a/go/vt/vtgate/grpcvtgateservice/server.go +++ b/go/vt/vtgate/grpcvtgateservice/server.go @@ -19,6 +19,7 @@ package grpcvtgateservice import ( "context" + "sync" "github.com/spf13/pflag" "google.golang.org/grpc" @@ -176,6 +177,59 @@ func (vtg *VTGate) ExecuteBatch(ctx context.Context, request *vtgatepb.ExecuteBa }, nil } +// streamSender is used to send messages on a stream. It is required to ensure that all the Send messages are sent from the same go-routine +// to ensure that we don't break the gRPC contract. +type streamSender struct { + ch chan *streamResponseWrapper +} + +// streamResponseWrapper is a used to wrap the response and the error together in a single element to ensure we pair the response +// with the error message we receive on sending it. +type streamResponseWrapper struct { + resp *vtgatepb.StreamExecuteResponse + err error + wg sync.WaitGroup +} + +// newStreamSender creates a new streamSender. +func newStreamSender() *streamSender { + return &streamSender{ + ch: make(chan *streamResponseWrapper), + } +} + +// start spins the go routine meant to be used to send all the messages. +func (ss *streamSender) start(stream vtgateservicepb.Vitess_StreamExecuteServer) { + go func() { + // Keep reading from the channel until it has been closed. + for srw := range ss.ch { + // Send the response on the stream and mark the wait group completed, once the message has been sent. + srw.err = stream.Send(srw.resp) + srw.wg.Done() + } + }() +} + +// close closes the stream sender. +func (ss *streamSender) close() { + close(ss.ch) +} + +// sendMessage sends a message using the stream sender. +func (ss *streamSender) sendMessage(resp *vtgatepb.StreamExecuteResponse) error { + // create a new stream wrapper + sh := &streamResponseWrapper{ + resp: resp, + wg: sync.WaitGroup{}, + } + // Add to the wait group and send the message on the channel. + sh.wg.Add(1) + ss.ch <- sh + // We now wait for the message to be sent and the error field populated. + sh.wg.Wait() + return sh.err +} + // StreamExecute is the RPC version of vtgateservice.VTGateService method func (vtg *VTGate) StreamExecute(request *vtgatepb.StreamExecuteRequest, stream vtgateservicepb.Vitess_StreamExecuteServer) (err error) { defer vtg.server.HandlePanic(&err) @@ -187,12 +241,15 @@ func (vtg *VTGate) StreamExecute(request *vtgatepb.StreamExecuteRequest, stream session = &vtgatepb.Session{Autocommit: true} } + ss := newStreamSender() + ss.start(stream) + defer ss.close() + session, vtgErr := vtg.server.StreamExecute(ctx, nil, session, request.Query.Sql, request.Query.BindVariables, func(value *sqltypes.Result) error { - // Send is not safe to call concurrently, but vtgate - // guarantees that it's not. - return stream.Send(&vtgatepb.StreamExecuteResponse{ + resp := &vtgatepb.StreamExecuteResponse{ Result: sqltypes.ResultToProto3(value), - }) + } + return ss.sendMessage(resp) }) var errs []error @@ -203,7 +260,7 @@ func (vtg *VTGate) StreamExecute(request *vtgatepb.StreamExecuteRequest, stream if sendSessionInStreaming { // even if there is an error, session could have been modified. // So, this needs to be sent back to the client. Session is sent in the last stream response. - lastErr := stream.Send(&vtgatepb.StreamExecuteResponse{ + lastErr := ss.sendMessage(&vtgatepb.StreamExecuteResponse{ Session: session, }) if lastErr != nil { diff --git a/go/vt/vtgate/grpcvtgateservice/server_test.go b/go/vt/vtgate/grpcvtgateservice/server_test.go new file mode 100644 index 00000000000..357ca1f5d04 --- /dev/null +++ b/go/vt/vtgate/grpcvtgateservice/server_test.go @@ -0,0 +1,186 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcvtgateservice + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/onsi/gomega/gleak/goroutine" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "vitess.io/vitess/go/sqltypes" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vtgate/vtgateservice" +) + +type mockVtgateService struct{} + +func (m *mockVtgateService) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) { + //TODO implement me + panic("implement me") +} + +func (m *mockVtgateService) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) { + //TODO implement me + panic("implement me") +} + +// StreamExecute in mockVtgateService calls the callback from two different go routines. +func (m *mockVtgateService) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) { + resOne := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int64|int64"), "1|1", "2|1") + resTwo := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int64|int64"), "1|1", "2|2") + var errOne, errTwo error + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + errOne = callback(resOne) + wg.Done() + }() + go func() { + errTwo = callback(resTwo) + wg.Done() + }() + wg.Wait() + if errOne != nil { + return session, errOne + } + return session, errTwo +} + +func (m *mockVtgateService) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) { + //TODO implement me + panic("implement me") +} + +func (m *mockVtgateService) CloseSession(ctx context.Context, session *vtgatepb.Session) error { + //TODO implement me + panic("implement me") +} + +func (m *mockVtgateService) ResolveTransaction(ctx context.Context, dtid string) error { + //TODO implement me + panic("implement me") +} + +func (m *mockVtgateService) VStream(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid, filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags, send func([]*binlogdatapb.VEvent) error) error { + //TODO implement me + panic("implement me") +} + +func (m *mockVtgateService) HandlePanic(err *error) {} + +var _ vtgateservice.VTGateService = (*mockVtgateService)(nil) + +type mockStreamExecuteServer struct { + mu sync.Mutex + goRoutineId uint64 +} + +// Send in mockStreamExecuteServer stores the go routine ID that it is called from. +// If Send is called from 2 different go-routines, then it throws an error. +func (m *mockStreamExecuteServer) Send(response *vtgatepb.StreamExecuteResponse) error { + m.mu.Lock() + defer m.mu.Unlock() + gr := goroutine.Current() + if m.goRoutineId == 0 { + m.goRoutineId = gr.ID + } + if gr.ID != m.goRoutineId { + return fmt.Errorf("two go routines are calling Send - %v and %v", gr.ID, m.goRoutineId) + } + return nil +} + +func (m *mockStreamExecuteServer) SetHeader(md metadata.MD) error { + //TODO implement me + panic("implement me") +} + +func (m *mockStreamExecuteServer) SendHeader(md metadata.MD) error { + //TODO implement me + panic("implement me") +} + +func (m *mockStreamExecuteServer) SetTrailer(md metadata.MD) { + //TODO implement me + panic("implement me") +} + +func (m *mockStreamExecuteServer) Context() context.Context { + return context.Background() +} + +func (m *mockStreamExecuteServer) SendMsg(msg any) error { + //TODO implement me + panic("implement me") +} + +func (m *mockStreamExecuteServer) RecvMsg(msg any) error { + //TODO implement me + panic("implement me") +} + +var _ vtgateservicepb.Vitess_StreamExecuteServer = (*mockStreamExecuteServer)(nil) + +// TestVTGateStreamExecuteConcurrency tests that calling StreamExecute with a mock executor that calls +// Send from 2 different go routines is safe. +func TestVTGateStreamExecuteConcurrency(t *testing.T) { + testcases := []struct { + name string + sendSession bool + }{ + { + name: "send session", + sendSession: true, + }, + { + name: "dont send session", + sendSession: false, + }, + } + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + oldVal := sendSessionInStreaming + defer func() { + sendSessionInStreaming = oldVal + }() + sendSessionInStreaming = tt.sendSession + vtg := &VTGate{ + UnimplementedVitessServer: vtgateservicepb.UnimplementedVitessServer{}, + server: &mockVtgateService{}, + } + err := vtg.StreamExecute(&vtgatepb.StreamExecuteRequest{ + CallerId: &vtrpcpb.CallerID{}, + Query: &querypb.BoundQuery{}, + }, &mockStreamExecuteServer{ + mu: sync.Mutex{}, + goRoutineId: 0, + }) + require.NoError(t, err) + }) + } + +}