Skip to content

Commit

Permalink
Revert GRPC context changes
Browse files Browse the repository at this point in the history
These changes were made assuming we could use `grpc.NewClient`
everywhere. But we had failing test for that. In a specific case, we
replace it with `grpc.Dial` but that's not the same.

If we can't move it all to `grpc.NewClient` yet, we have to keep the
context setup we do as well, or otherwise `grpc.Dial` could hang
indefinitely even if the caller context is cancelled since it gets
started with the background context.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Apr 23, 2024
1 parent 8844aba commit 4736a0e
Show file tree
Hide file tree
Showing 33 changed files with 94 additions and 71 deletions.
2 changes: 1 addition & 1 deletion examples/local/vstream_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func main() {
Filter: "select * from customer",
}},
}
conn, err := vtgateconn.Dial("localhost:15991")
conn, err := vtgateconn.Dial(ctx, "localhost:15991")
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/cluster/cluster_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ func (cluster *LocalProcessCluster) ExecOnVTGate(ctx context.Context, addr strin
return nil, err
}

conn, err := vtgateconn.Dial(addr)
conn, err := vtgateconn.Dial(ctx, addr)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/cluster/cluster_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,13 +482,13 @@ func WaitForHealthyShard(vtctldclient *VtctldClientProcess, keyspace, shard stri
}

// DialVTGate returns a VTGate grpc connection.
func DialVTGate(name, addr, username, password string) (*vtgateconn.VTGateConn, error) {
func DialVTGate(ctx context.Context, name, addr, username, password string) (*vtgateconn.VTGateConn, error) {
clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password}
creds := grpc.WithPerRPCCredentials(clientCreds)
dialerFunc := grpcvtgateconn.Dial(creds)
dialerName := name
vtgateconn.RegisterDialer(dialerName, dialerFunc)
return vtgateconn.DialProtocol(dialerName, addr)
return vtgateconn.DialProtocol(ctx, dialerName, addr)
}

// PrintFiles prints the files that are asked for. If no file is specified, all the files are printed.
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/messaging/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func VtgateGrpcConn(ctx context.Context, cluster *cluster.LocalProcessCluster) (
stream := new(VTGateStream)
stream.ctx = ctx
stream.host = fmt.Sprintf("%s:%d", cluster.Hostname, cluster.VtgateProcess.GrpcPort)
conn, err := vtgateconn.Dial(stream.host)
conn, err := vtgateconn.Dial(ctx, stream.host)
// init components
stream.respChan = make(chan *sqltypes.Result)
stream.VTGateConn = conn
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/recovery/unshardedrecovery/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func TestRecoveryImpl(t *testing.T) {

// Build vtgate grpc connection
grpcAddress := fmt.Sprintf("%s:%d", localCluster.Hostname, localCluster.VtgateGrpcPort)
vtgateConn, err := vtgateconn.Dial(grpcAddress)
vtgateConn, err := vtgateconn.Dial(context.Background(), grpcAddress)
assert.NoError(t, err)
defer vtgateConn.Close()
session := vtgateConn.Session("@replica", nil)
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/tabletgateway/vtgate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func TestStreamingRPCStuck(t *testing.T) {
}

// Connect to vtgate and run a streaming query.
vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "test_user", "")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "test_user", "")
require.NoError(t, err)
stream, err := vtgateConn.Session("", &querypb.ExecuteOptions{}).StreamExecute(ctx, "select * from customer", map[string]*querypb.BindVariable{})
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/vreplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func testVStreamCellFlag(t *testing.T) {

for _, tc := range vstreamTestCases {
t.Run("VStreamCellsFlag/"+tc.cells, func(t *testing.T) {
conn, err := vtgateconn.Dial(fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort))
conn, err := vtgateconn.Dial(ctx, fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort))
require.NoError(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/vschema_load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestVSchemaChangesUnderLoad(t *testing.T) {
Filter: "select * from customer",
}},
}
conn, err := vtgateconn.Dial(net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort)))
conn, err := vtgateconn.Dial(ctx, net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort)))
require.NoError(t, err)
defer conn.Close()

Expand Down
8 changes: 4 additions & 4 deletions go/test/endtoend/vreplication/vstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func testVStreamWithFailover(t *testing.T, failover bool) {
testVStreamFrom(t, vtgate, "product", 2)
})
ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func testVStreamStopOnReshardFlag(t *testing.T, stopOnReshard bool, baseTabletID
vc.AddKeyspace(t, []*Cell{defaultCell}, "sharded", "-80,80-", vschemaSharded, schemaSharded, defaultReplicas, defaultRdonly, baseTabletID+200, nil)

ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -398,7 +398,7 @@ func testVStreamCopyMultiKeyspaceReshard(t *testing.T, baseTabletID int) numEven
require.NoError(t, err)

ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -550,7 +550,7 @@ func TestMultiVStreamsKeyspaceReshard(t *testing.T) {
defer vtgateConn.Close()
verifyClusterHealth(t, vc)

vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
require.NoError(t, err)
defer vstreamConn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtcombo/recreate/recreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestMain(m *testing.M) {

func TestDropAndRecreateWithSameShards(t *testing.T) {
ctx := context.Background()
conn, err := vtgateconn.Dial(grpcAddress)
conn, err := vtgateconn.Dial(ctx, grpcAddress)
require.Nil(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtcombo/vttest_sample_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestStandalone(t *testing.T) {
require.Contains(t, tmp[0], "vtcombo")

ctx := context.Background()
conn, err := vtgateconn.Dial(grpcAddress)
conn, err := vtgateconn.Dial(ctx, grpcAddress)
require.NoError(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestUpdateWithFK(t *testing.T) {

// TestVstreamForFKBinLog tests that dml queries with fks are written with child row first approach in the binary logs.
func TestVstreamForFKBinLog(t *testing.T) {
vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "fk_user", "")
vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
10 changes: 5 additions & 5 deletions go/test/endtoend/vtgate/grpc_api/acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestEffectiveCallerIDWithAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "some_other_user", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "some_other_user", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -48,7 +48,7 @@ func TestEffectiveCallerIDWithNoAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -66,7 +66,7 @@ func TestAuthenticatedUserWithAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -81,7 +81,7 @@ func TestAuthenticatedUserNoAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_no_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_no_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -98,7 +98,7 @@ func TestUnauthenticatedUser(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "", "")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "", "")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/grpc_api/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestTransactionsWithGRPCAPI(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/reference/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestMain(m *testing.M) {
go func() {
ctx := context.Background()
vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort)
vtgateConn, err := vtgateconn.Dial(vtgateAddr)
vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr)
if err != nil {
done <- false
return
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestMain(m *testing.M) {

ctx := context.Background()
vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort)
vtgateConn, err := vtgateconn.Dial(vtgateAddr)
vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr)
if err != nil {
return 1
}
Expand Down
13 changes: 12 additions & 1 deletion go/vt/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
package grpcclient

import (
"context"
"crypto/tls"
"sync"
"time"
Expand Down Expand Up @@ -96,6 +97,16 @@ func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([
// failFast is a non-optional parameter because callers are required to specify
// what that should be.
func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return DialContext(context.Background(), target, failFast, opts...)
}

// DialContext creates a grpc connection to the given target. Setup steps are
// covered by the context deadline, and, if WithBlock is specified in the dial
// options, connection establishment steps are covered by the context as well.
//
// failFast is a non-optional parameter because callers are required to specify
// what that should be.
func DialContext(ctx context.Context, target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
msgSize := grpccommon.MaxMessageSize()
newopts := []grpc.DialOption{
grpc.WithDefaultCallOptions(
Expand Down Expand Up @@ -138,7 +149,7 @@ func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.Clie

newopts = append(newopts, interceptors()...)

return grpc.Dial(target, newopts...)
return grpc.DialContext(ctx, target, newopts...)
}

func interceptors() []grpc.DialOption {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/grpcoptionaltls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestOptionalTLS(t *testing.T) {
testFunc := func(t *testing.T, dialOpt grpc.DialOption) {
ctx, cancel := context.WithTimeout(testCtx, 5*time.Second)
defer cancel()
conn, err := grpc.NewClient(addr, dialOpt)
conn, err := grpc.DialContext(ctx, addr, dialOpt)
if err != nil {
t.Fatalf("failed to connect to the server %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ func (d drv) newConnector(cfg Configuration) (driver.Connector, error) {
}

// Connect implements the database/sql/driver.Connector interface.
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
conn := &conn{
cfg: c.cfg,
convert: c.convert,
}

if err := conn.dial(); err != nil {
if err := conn.dial(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -267,9 +267,9 @@ type conn struct {
session *vtgateconn.VTGateSession
}

func (c *conn) dial() error {
func (c *conn) dial(ctx context.Context) error {
var err error
c.conn, err = vtgateconn.DialProtocol(c.cfg.Protocol, c.cfg.Address)
c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtadmin/grpcserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestServer(t *testing.T) {
}
close(readyCh)

conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
assert.NoError(t, err)

defer conn.Close()
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/endtoend/vstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import (
)

func initialize(ctx context.Context, t *testing.T) (*vtgateconn.VTGateConn, *mysql.Conn, *mysql.Conn, func()) {
gconn, err := vtgateconn.Dial(grpcAddress)
gconn, err := vtgateconn.Dial(ctx, grpcAddress)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/fakerpcvtgateconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func RegisterFakeVTGateConnDialer() (*FakeVTGateConn, string) {
impl := &FakeVTGateConn{
execMap: make(map[string]*queryResponse),
}
vtgateconn.RegisterDialer(protocol, func(address string) (vtgateconn.Impl, error) {
vtgateconn.RegisterDialer(protocol, func(ctx context.Context, address string) (vtgateconn.Impl, error) {
return impl, nil
})
return impl, protocol
Expand Down
16 changes: 12 additions & 4 deletions go/vt/vtgate/grpcvtgateconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ type vtgateConn struct {
c vtgateservicepb.VitessClient
}

func dial(addr string) (vtgateconn.Impl, error) {
return Dial()(addr)
func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) {
return Dial()(ctx, addr)
}

// Dial produces a vtgateconn.DialerFunc with custom options.
func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc {
return func(address string) (vtgateconn.Impl, error) {
return func(ctx context.Context, address string) (vtgateconn.Impl, error) {
opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name)
if err != nil {
return nil, err
}

opts = append(opts, opt)

cc, err := grpcclient.Dial(address, grpcclient.FailFast(false), opts...)
cc, err := grpcclient.DialContext(ctx, address, grpcclient.FailFast(false), opts...)
if err != nil {
return nil, err
}
Expand All @@ -99,6 +99,14 @@ func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc {
}
}

// DialWithOpts allows for custom dial options to be set on a vtgateConn.
//
// Deprecated: the context parameter cannot be used by the returned
// vtgateconn.DialerFunc and thus has no effect. Use Dial instead.
func DialWithOpts(_ context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc {
return Dial(opts...)
}

func (conn *vtgateConn) Execute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
request := &vtgatepb.ExecuteRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Expand Down
Loading

0 comments on commit 4736a0e

Please sign in to comment.