From b35593cd2785abd131dfa04709c03fdf1acf1fab Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Mon, 21 Oct 2024 13:53:16 -0400 Subject: [PATCH 1/5] Setup new fake server if it has gone away And plug unit test goroutine leaks Signed-off-by: Matt Lord --- go/mysql/auth_server_clientcert_test.go | 70 ++-- go/mysql/auth_server_static_test.go | 24 +- go/mysql/handshake_test.go | 39 +- go/mysql/replication_test.go | 3 + go/mysql/server_test.go | 470 ++++++++++++++---------- 5 files changed, 357 insertions(+), 249 deletions(-) diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index eff92053d94..95f7125a6b3 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -17,7 +17,6 @@ limitations under the License. package mysql import ( - "context" "crypto/tls" "net" "path" @@ -27,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/tlstest" "vitess.io/vitess/go/vt/vttls" ) @@ -34,6 +34,7 @@ import ( const clientCertUsername = "Client Cert" func TestValidCert(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerClientCert(string(MysqlClearPassword)) @@ -52,21 +53,6 @@ func TestValidCert(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername) tlstest.CreateCRL(root, tlstest.CA) - // Create the server with TLS config. - serverConfig, err := vttls.ServerConfig( - path.Join(root, "server-cert.pem"), - path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem"), - path.Join(root, "ca-crl.pem"), - "", - tls.VersionTLS12) - require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -81,7 +67,25 @@ func TestValidCert(t *testing.T) { ServerName: "server.example.com", } - ctx := context.Background() + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + path.Join(root, "ca-crl.pem"), + "", + tls.VersionTLS12) + require.NoError(t, err, "TLSServerConfig failed: %v", err) + + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() + conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) @@ -103,6 +107,7 @@ func TestValidCert(t *testing.T) { } func TestNoCert(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerClientCert(string(MysqlClearPassword)) @@ -120,6 +125,17 @@ func TestNoCert(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateCRL(root, tlstest.CA) + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "", + SslMode: vttls.VerifyIdentity, + SslCa: path.Join(root, "ca-cert.pem"), + ServerName: "server.example.com", + } + // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -131,22 +147,14 @@ func TestNoCert(t *testing.T) { require.NoError(t, err, "TLSServerConfig failed: %v", err) l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) }() - // Setup the right parameters. - params := &ConnParams{ - Host: host, - Port: port, - Uname: "user1", - Pass: "", - SslMode: vttls.VerifyIdentity, - SslCa: path.Join(root, "ca-cert.pem"), - ServerName: "server.example.com", - } - - ctx := context.Background() conn, err := Connect(ctx, params) assert.Error(t, err, "Connect() should have errored due to no client cert") diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 12ae74e0d60..a5957cccfed 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -25,6 +25,8 @@ import ( "time" "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/test/utils" ) // getEntries is a test-only method for AuthServerStatic. @@ -105,9 +107,9 @@ func TestHostMatcher(t *testing.T) { } func TestStaticConfigHUP(t *testing.T) { + _ = utils.LeakCheckContext(t) tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") require.NoError(t, err, "couldn't create temp file: %v", err) - defer os.Remove(tmpFile.Name()) oldStr := "str5" @@ -125,14 +127,18 @@ func TestStaticConfigHUP(t *testing.T) { mu.Lock() defer mu.Unlock() - // delete registered Auth server - clear(authServers) + // Delete registered Auth servers. + for k, v := range authServers { + if s, ok := v.(*AuthServerStatic); ok { + s.close() + } + delete(authServers, k) + } } func TestStaticConfigHUPWithRotation(t *testing.T) { tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") require.NoError(t, err, "couldn't create temp file: %v", err) - defer os.Remove(tmpFile.Name()) oldStr := "str1" @@ -147,6 +153,16 @@ func TestStaticConfigHUPWithRotation(t *testing.T) { hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4") hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5") + + mu.Lock() + defer mu.Unlock() + // Delete registered Auth servers. + for k, v := range authServers { + if s, ok := v.(*AuthServerStatic); ok { + s.close() + } + delete(authServers, k) + } } func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 284189c30e8..b71976e3064 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -37,6 +37,7 @@ import ( // This file tests the handshake scenarios between our client and our server. func TestClearTextClientAuth(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -51,10 +52,6 @@ func TestClearTextClientAuth(t *testing.T) { defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -63,9 +60,15 @@ func TestClearTextClientAuth(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Connection should fail, as server requires SSL for clear text auth. - ctx := context.Background() _, err = Connect(ctx, params) if err == nil || !strings.Contains(err.Error(), "Cannot use clear text authentication over non-SSL connections") { t.Fatalf("unexpected connection error: %v", err) @@ -92,6 +95,7 @@ func TestClearTextClientAuth(t *testing.T) { // TestSSLConnection creates a server with TLS support, a client that // also has SSL support, and connects them. func TestSSLConnection(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -103,7 +107,6 @@ func TestSSLConnection(t *testing.T) { // Create the listener, so we can get its host. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port @@ -122,12 +125,6 @@ func TestSSLConnection(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -141,20 +138,27 @@ func TestSSLConnection(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() t.Run("Basics", func(t *testing.T) { - testSSLConnectionBasics(t, params) + testSSLConnectionBasics(t, ctx, params) }) // Make sure clear text auth works over SSL. t.Run("ClearText", func(t *testing.T) { - testSSLConnectionClearText(t, params) + testSSLConnectionClearText(t, ctx, params) }) } -func testSSLConnectionClearText(t *testing.T, params *ConnParams) { +func testSSLConnectionClearText(t *testing.T, ctx context.Context, params *ConnParams) { // Create a client connection, connect. - ctx := context.Background() conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) @@ -170,9 +174,8 @@ func testSSLConnectionClearText(t *testing.T, params *ConnParams) { conn.writeComQuit() } -func testSSLConnectionBasics(t *testing.T, params *ConnParams) { +func testSSLConnectionBasics(t *testing.T, ctx context.Context, params *ConnParams) { // Create a client connection, connect. - ctx := context.Background() conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) diff --git a/go/mysql/replication_test.go b/go/mysql/replication_test.go index c397bc71b45..a5310a03764 100644 --- a/go/mysql/replication_test.go +++ b/go/mysql/replication_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/utils" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" ) @@ -72,6 +73,7 @@ func TestComBinlogDump(t *testing.T) { } func TestComBinlogDumpGTID(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close() @@ -161,6 +163,7 @@ func TestComBinlogDumpGTID(t *testing.T) { } func TestSendSemiSyncAck(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close() diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 72b6f25d0c8..0b0356bbbd1 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -267,6 +267,7 @@ func getHostPort(t *testing.T, a net.Addr) (string, int) { } func TestConnectionFromListener(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -282,9 +283,6 @@ func TestConnectionFromListener(t *testing.T) { l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) fmt.Printf("host: %s, port: %d\n", host, port) // Setup the right parameters. @@ -294,13 +292,21 @@ func TestConnectionFromListener(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestConnectionWithoutSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -309,13 +315,10 @@ func TestConnectionWithoutSourceHost(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -323,13 +326,21 @@ func TestConnectionWithoutSourceHost(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestConnectionWithSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -344,11 +355,7 @@ func TestConnectionWithSourceHost(t *testing.T) { l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -356,13 +363,21 @@ func TestConnectionWithSourceHost(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - _, err = Connect(context.Background(), params) + _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") } func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -377,11 +392,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -389,13 +400,21 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { Uname: "user1", Pass: "mysql_password", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - _, err = Connect(context.Background(), params) + _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server") } func TestConnectionUnixSocket(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -415,22 +434,27 @@ func TestConnectionUnixSocket(t *testing.T) { l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - // Setup the right parameters. params := &ConnParams{ UnixSocket: unixSocket.Name(), Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") c.Close() } func TestClientFoundRows(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -439,13 +463,10 @@ func TestClientFoundRows(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -453,9 +474,16 @@ func TestClientFoundRows(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Test without flag. - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Connect failed") foundRows := th.LastConn().Capabilities & CapabilityClientFoundRows assert.Equal(t, uint32(0), foundRows, "FoundRows flag: %x, second bit must be 0", th.LastConn().Capabilities) @@ -464,7 +492,7 @@ func TestClientFoundRows(t *testing.T) { // Test with flag. params.Flags |= CapabilityClientFoundRows - c, err = Connect(context.Background(), params) + c, err = Connect(ctx, params) require.NoError(t, err, "Connect failed") foundRows = th.LastConn().Capabilities & CapabilityClientFoundRows assert.NotZero(t, foundRows, "FoundRows flag: %x, second bit must be set", th.LastConn().Capabilities) @@ -472,6 +500,7 @@ func TestClientFoundRows(t *testing.T) { } func TestConnCounts(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} user := "anotherNotYetConnectedUser1" @@ -483,13 +512,10 @@ func TestConnCounts(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed") - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Test with one new connection. params := &ConnParams{ Host: host, @@ -497,14 +523,21 @@ func TestConnCounts(t *testing.T) { Uname: user, Pass: passwd, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err, "Connect failed") checkCountsForUser(t, user, 1) // Test with a second new connection. - c2, err := Connect(context.Background(), params) + c2, err := Connect(ctx, params) require.NoError(t, err) checkCountsForUser(t, user, 2) @@ -529,6 +562,7 @@ func checkCountsForUser(t assert.TestingT, user string, expected int64) { } func TestServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -537,14 +571,10 @@ func TestServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -552,6 +582,14 @@ func TestServer(t *testing.T) { Uname: "user1", Pass: "password1", } + l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Run a 'select rows' command with results. output, err := runMysqlWithErr(t, params, "select rows") @@ -629,6 +667,7 @@ func TestServer(t *testing.T) { } func TestServerStats(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -637,14 +676,10 @@ func TestServerStats(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -652,6 +687,14 @@ func TestServerStats(t *testing.T) { Uname: "user1", Pass: "password1", } + l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() timings.Reset() connAccept.Reset() @@ -703,6 +746,7 @@ func TestServerStats(t *testing.T) { // TestClearTextServer creates a Server that needs clear text // passwords from the client. func TestClearTextServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword) @@ -711,16 +755,10 @@ func TestClearTextServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - - version, _ := runMysql(t, nil, "--version") - isMariaDB := strings.Contains(version, "MariaDB") - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -728,6 +766,16 @@ func TestClearTextServer(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() + + version, _ := runMysql(t, nil, "--version") + isMariaDB := strings.Contains(version, "MariaDB") // Run a 'select rows' command with results. This should fail // as clear text is not enabled by default on the client @@ -776,6 +824,7 @@ func TestClearTextServer(t *testing.T) { // TestDialogServer creates a Server that uses the dialog plugin on the client. func TestDialogServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlDialog) @@ -784,14 +833,11 @@ func TestDialogServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) l.AllowClearTextWithoutTLS.Store(true) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -800,6 +846,14 @@ func TestDialogServer(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() + sql := "select rows" output, ok := runMysql(t, params, sql) if strings.Contains(output, "No such file or directory") || strings.Contains(output, "Authentication plugin 'dialog' cannot be loaded") { @@ -815,6 +869,7 @@ func TestDialogServer(t *testing.T) { // TestTLSServer creates a Server with TLS support, then uses mysql // client to connect to it. func TestTLSServer(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -823,46 +878,20 @@ func TestTLSServer(t *testing.T) { }} defer authServer.close() + // Create the certs. + root := t.TempDir() + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - - // Create the certs. - root := t.TempDir() - tlstest.CreateCA(root) - tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") - tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") - - // Create the server with TLS config. - serverConfig, err := vttls.ServerConfig( - path.Join(root, "server-cert.pem"), - path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem"), - "", - "", - tls.VersionTLS12) - require.NoError(t, err) - l.TLSConfig.Store(serverConfig) - - var wg sync.WaitGroup - wg.Add(1) - go func(l *Listener) { - wg.Done() - l.Accept() - }(l) - // This is ensure the listener is called - wg.Wait() - // Sleep so that the Accept function is called as well.' - time.Sleep(3 * time.Second) - - connCountByTLSVer.ResetAll() // Setup the right parameters. params := &ConnParams{ Host: host, @@ -876,9 +905,28 @@ func TestTLSServer(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + "", + "", + tls.VersionTLS12) + require.NoError(t, err) + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() + + connCountByTLSVer.ResetAll() // Run a 'select rows' command with results. - conn, err := Connect(context.Background(), params) + conn, err := Connect(ctx, params) // output, ok := runMysql(t, params, "select rows") require.NoError(t, err) results, err := conn.ExecuteFetch("select rows", 1000, true) @@ -913,25 +961,9 @@ func TestTLSServer(t *testing.T) { // TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql // client is rejected func TestTLSRequired(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} - authServer := NewAuthServerStatic("", "", 0) - authServer.entries["user1"] = []*AuthServerStaticEntry{{ - Password: "password1", - }} - defer authServer.close() - - // Create the listener, so we can get its host. - // Below, we are enabling --ssl-verify-server-cert, which adds - // a check that the common name of the certificate matches the - // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err) - defer l.Close() - - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) @@ -940,6 +972,13 @@ func TestTLSRequired(t *testing.T) { tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert") tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client") + params := &ConnParams{ + Uname: "user1", + Pass: "password1", + SslMode: vttls.Disabled, // TLS is disabled at first + ServerName: "server.example.com", + } + // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -949,34 +988,55 @@ func TestTLSRequired(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err) - l.TLSConfig.Store(serverConfig) - l.RequireSecureTransport = true - - var wg sync.WaitGroup - wg.Add(1) - go func(l *Listener) { - wg.Done() - l.Accept() - }(l) - // This is ensure the listener is called - wg.Wait() - // Sleep so that the Accept function is called as well.' - time.Sleep(3 * time.Second) - - // Setup conn params without SSL. - params := &ConnParams{ - Host: host, - Port: port, - Uname: "user1", - Pass: "password1", - SslMode: vttls.Disabled, - ServerName: "server.example.com", + + authServer := NewAuthServerStatic("", "", 0) + authServer.entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + }} + defer authServer.close() + + var l *Listener + setupServer := func() { + // Create the listener, so we can get its host. + // Below, we are enabling --ssl-verify-server-cert, which adds + // a check that the common name of the certificate matches the + // server host name we connect to. + l, err = NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port + l.TLSConfig.Store(serverConfig) + l.RequireSecureTransport = true + go l.Accept() + params.Host = host + params.Port = port + } + setupServer() + + cleanup := func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) } - conn, err := Connect(context.Background(), params) - require.NotNil(t, err) - require.Contains(t, err.Error(), "Code: UNAVAILABLE") - require.Contains(t, err.Error(), "server does not allow insecure connections, client must use SSL/TLS") - require.Contains(t, err.Error(), "(errno 1105) (sqlstate HY000)") + defer cleanup() + + // This test calls Connect multiple times so we add handling for when the + // listener goes away for any reason. + connectWithGoneServerHandling := func() (*Conn, error) { + conn, err := Connect(ctx, params) + if sqlErr, ok := sqlerror.NewSQLErrorFromError(err).(*sqlerror.SQLError); ok && sqlErr.Num == sqlerror.CRConnHostError { + cleanup() + setupServer() + conn, err = Connect(ctx, params) + } + return conn, err + } + + conn, err := connectWithGoneServerHandling() + require.ErrorContains(t, err, "Code: UNAVAILABLE") + require.ErrorContains(t, err, "server does not allow insecure connections, client must use SSL/TLS") + require.ErrorContains(t, err, "(errno 1105) (sqlstate HY000)") if conn != nil { conn.Close() } @@ -987,7 +1047,7 @@ func TestTLSRequired(t *testing.T) { params.SslCert = path.Join(root, "client-cert.pem") params.SslKey = path.Join(root, "client-key.pem") - conn, err = Connect(context.Background(), params) + conn, err = connectWithGoneServerHandling() require.NoError(t, err) if conn != nil { conn.Close() @@ -996,15 +1056,15 @@ func TestTLSRequired(t *testing.T) { // setup conn params with TLS, but with a revoked client certificate params.SslCert = path.Join(root, "revoked-client-cert.pem") params.SslKey = path.Join(root, "revoked-client-key.pem") - conn, err = Connect(context.Background(), params) - require.NotNil(t, err) - require.Contains(t, err.Error(), "remote error: tls: bad certificate") + conn, err = connectWithGoneServerHandling() + require.ErrorContains(t, err, "remote error: tls: bad certificate") if conn != nil { conn.Close() } } func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) @@ -1013,19 +1073,17 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { } defer authServer.close() - // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err, "NewListener failed: %v", err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -1035,12 +1093,6 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1054,10 +1106,16 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() - conn, err := Connect(ctx, params) require.NoError(t, err, "unexpected connection error: %v", err) @@ -1093,12 +1151,11 @@ func newAuthServerAlwaysFallback(file, jsonConfig string, reloadInterval time.Du authMethod := NewSha2CachingAuthMethod(&alwaysFallbackAuth{}, a, a) a.methods = []AuthMethod{authMethod} - a.reload() - a.installSignalHandlers() return a } func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := newAuthServerAlwaysFallback("", "", 0) @@ -1107,19 +1164,17 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { } defer authServer.close() - // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) - require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() - host := l.Addr().(*net.TCPAddr).IP.String() - port := l.Addr().(*net.TCPAddr).Port - // Create the certs. root := t.TempDir() tlstest.CreateCA(root) tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + // Create the listener, so we can get its host. + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) + require.NoError(t, err, "NewListener failed: %v", err) + host := l.Addr().(*net.TCPAddr).IP.String() + port := l.Addr().(*net.TCPAddr).Port // Create the server with TLS config. serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), @@ -1129,12 +1184,6 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { "", tls.VersionTLS12) require.NoError(t, err, "TLSServerConfig failed: %v", err) - - l.TLSConfig.Store(serverConfig) - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1148,10 +1197,16 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { SslKey: path.Join(root, "client-key.pem"), ServerName: "server.example.com", } + l.TLSConfig.Store(serverConfig) + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() - conn, err := Connect(ctx, params) require.NoError(t, err, "unexpected connection error: %v", err) @@ -1168,6 +1223,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { } func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password) @@ -1179,13 +1235,8 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { // Create the listener. l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err, "NewListener failed: %v", err) - defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() port := l.Addr().(*net.TCPAddr).Port - go func() { - l.Accept() - }() - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1194,9 +1245,15 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { Pass: "password1", SslMode: vttls.Disabled, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // Connection should fail, as server requires SSL for caching_sha2_password. - ctx := context.Background() _, err = Connect(ctx, params) if err == nil || !strings.Contains(err.Error(), "No authentication methods available for authentication") { t.Fatalf("unexpected connection error: %v", err) @@ -1211,6 +1268,7 @@ func checkCountForTLSVer(t *testing.T, version string, expected int64) { } func TestErrorCodes(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) @@ -1219,13 +1277,10 @@ func TestErrorCodes(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1233,8 +1288,14 @@ func TestErrorCodes(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - ctx := context.Background() client, err := Connect(ctx, params) require.NoError(t, err) @@ -1390,6 +1451,7 @@ func binaryPath(root, binary string) (string, error) { } func TestListenerShutdown(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} authServer := NewAuthServerStatic("", "", 0) authServer.entries["user1"] = []*AuthServerStaticEntry{{ @@ -1397,13 +1459,10 @@ func TestListenerShutdown(t *testing.T) { UserData: "userData1", }} defer authServer.close() + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) - // Setup the right parameters. params := &ConnParams{ Host: host, @@ -1411,9 +1470,17 @@ func TestListenerShutdown(t *testing.T) { Uname: "user1", Pass: "password1", } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() + connRefuse.Reset() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() conn, err := Connect(ctx, params) @@ -1424,7 +1491,7 @@ func TestListenerShutdown(t *testing.T) { l.Shutdown() - waitForConnRefuse(t, 1) + waitForConnRefuse(t, ctx, 1) err = conn.Ping() require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") @@ -1436,8 +1503,8 @@ func TestListenerShutdown(t *testing.T) { require.Equal(t, "Server shutdown in progress", sqlErr.Message) } -func waitForConnRefuse(t *testing.T, valWanted int64) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +func waitForConnRefuse(t *testing.T, ctx context.Context, valWanted int64) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() tick := time.NewTicker(100 * time.Millisecond) defer tick.Stop() @@ -1483,21 +1550,26 @@ func TestParseConnAttrs(t *testing.T) { } func TestServerFlush(t *testing.T) { + ctx := utils.LeakCheckContext(t) mysqlServerFlushDelay := 10 * time.Millisecond th := &testHandler{} l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, mysqlServerFlushDelay) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) params := &ConnParams{ Host: host, Port: port, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err) defer c.Close() @@ -1531,20 +1603,26 @@ func TestServerFlush(t *testing.T) { } func TestTcpKeepAlive(t *testing.T) { + ctx := utils.LeakCheckContext(t) th := &testHandler{} + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0, 0) require.NoError(t, err) - defer l.Close() - go l.Accept() - host, port := getHostPort(t, l.Addr()) params := &ConnParams{ Host: host, Port: port, } + go l.Accept() + defer func() { + l.Close() + // The accept loop actually only ends on a connection error, which will + // occur when trying to connect after the listener has been closed. + _, _ = Connect(ctx, params) + }() // on connect, the tcp method should be called. - c, err := Connect(context.Background(), params) + c, err := Connect(ctx, params) require.NoError(t, err) defer c.Close() require.True(t, th.lastConn.keepAliveOn, "tcp property method not called") From 0d05d3e00ee32ef4a5292de46949c80b24689340 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Mon, 21 Oct 2024 14:23:43 -0400 Subject: [PATCH 2/5] Leverage test helper everywhere Signed-off-by: Matt Lord --- go/mysql/auth_server_clientcert_test.go | 25 +++-- go/mysql/handshake_test.go | 14 +-- go/mysql/server_test.go | 143 ++++-------------------- 3 files changed, 36 insertions(+), 146 deletions(-) diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 95f7125a6b3..201c4d8dabd 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -17,6 +17,7 @@ limitations under the License. package mysql import ( + "context" "crypto/tls" "net" "path" @@ -33,6 +34,16 @@ import ( const clientCertUsername = "Client Cert" +// The listener's accept loop actually only ends on a connection +// error, which will occur when trying to connect after the listener +// has been closed. So this function closes the listener and then +// calls Connect to trigger the error which ends that conneciton +// request handler goroutine. +var cleanupListener = func(ctx context.Context, l *Listener, params *ConnParams) { + l.Close() + _, _ = Connect(ctx, params) +} + func TestValidCert(t *testing.T) { ctx := utils.LeakCheckContext(t) th := &testHandler{} @@ -79,12 +90,7 @@ func TestValidCert(t *testing.T) { l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) conn, err := Connect(ctx, params) require.NoError(t, err, "Connect failed: %v", err) @@ -148,12 +154,7 @@ func TestNoCert(t *testing.T) { l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) conn, err := Connect(ctx, params) assert.Error(t, err, "Connect() should have errored due to no client cert") diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index b71976e3064..13ed1099e58 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -61,12 +61,7 @@ func TestClearTextClientAuth(t *testing.T) { SslMode: vttls.Disabled, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for clear text auth. _, err = Connect(ctx, params) @@ -140,12 +135,7 @@ func TestSSLConnection(t *testing.T) { } l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) t.Run("Basics", func(t *testing.T) { testSSLConnectionBasics(t, ctx, params) diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 0b0356bbbd1..d4a619d1bc0 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -293,12 +293,7 @@ func TestConnectionFromListener(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") @@ -327,12 +322,7 @@ func TestConnectionWithoutSourceHost(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") @@ -364,12 +354,7 @@ func TestConnectionWithSourceHost(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection @@ -401,12 +386,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { Pass: "mysql_password", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) _, err = Connect(ctx, params) // target is localhost, should not work from tcp connection @@ -441,12 +421,7 @@ func TestConnectionUnixSocket(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) c, err := Connect(ctx, params) require.NoError(t, err, "Should be able to connect to server") @@ -475,12 +450,7 @@ func TestClientFoundRows(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Test without flag. c, err := Connect(ctx, params) @@ -524,12 +494,7 @@ func TestConnCounts(t *testing.T) { Pass: passwd, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) c, err := Connect(ctx, params) require.NoError(t, err, "Connect failed") @@ -584,12 +549,7 @@ func TestServer(t *testing.T) { } l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Run a 'select rows' command with results. output, err := runMysqlWithErr(t, params, "select rows") @@ -689,12 +649,7 @@ func TestServerStats(t *testing.T) { } l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) timings.Reset() connAccept.Reset() @@ -767,12 +722,7 @@ func TestClearTextServer(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) version, _ := runMysql(t, nil, "--version") isMariaDB := strings.Contains(version, "MariaDB") @@ -847,12 +797,7 @@ func TestDialogServer(t *testing.T) { SslMode: vttls.Disabled, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) sql := "select rows" output, ok := runMysql(t, params, sql) @@ -916,12 +861,7 @@ func TestTLSServer(t *testing.T) { require.NoError(t, err) l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) connCountByTLSVer.ResetAll() @@ -1013,20 +953,14 @@ func TestTLSRequired(t *testing.T) { } setupServer() - cleanup := func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - } - defer cleanup() + defer cleanupListener(ctx, l, params) // This test calls Connect multiple times so we add handling for when the // listener goes away for any reason. connectWithGoneServerHandling := func() (*Conn, error) { conn, err := Connect(ctx, params) if sqlErr, ok := sqlerror.NewSQLErrorFromError(err).(*sqlerror.SQLError); ok && sqlErr.Num == sqlerror.CRConnHostError { - cleanup() + cleanupListener(ctx, l, params) setupServer() conn, err = Connect(ctx, params) } @@ -1108,12 +1042,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { } l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. conn, err := Connect(ctx, params) @@ -1199,12 +1128,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { } l.TLSConfig.Store(serverConfig) go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. conn, err := Connect(ctx, params) @@ -1246,12 +1170,7 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { SslMode: vttls.Disabled, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // Connection should fail, as server requires SSL for caching_sha2_password. _, err = Connect(ctx, params) @@ -1289,12 +1208,7 @@ func TestErrorCodes(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) client, err := Connect(ctx, params) require.NoError(t, err) @@ -1471,12 +1385,7 @@ func TestListenerShutdown(t *testing.T) { Pass: "password1", } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) connRefuse.Reset() @@ -1562,12 +1471,7 @@ func TestServerFlush(t *testing.T) { Port: port, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) c, err := Connect(ctx, params) require.NoError(t, err) @@ -1614,12 +1518,7 @@ func TestTcpKeepAlive(t *testing.T) { Port: port, } go l.Accept() - defer func() { - l.Close() - // The accept loop actually only ends on a connection error, which will - // occur when trying to connect after the listener has been closed. - _, _ = Connect(ctx, params) - }() + defer cleanupListener(ctx, l, params) // on connect, the tcp method should be called. c, err := Connect(ctx, params) From bf8a9d52e3f2b0daea50d90b7b775b9c74239c44 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Mon, 21 Oct 2024 14:47:02 -0400 Subject: [PATCH 3/5] Plug remaining leaks Signed-off-by: Matt Lord --- go/mysql/auth_server_static.go | 24 ++++++++++++++++++++---- go/mysql/auth_server_static_test.go | 1 + 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index d9e6decf5e5..46302bcabe1 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -50,8 +50,10 @@ type AuthServerStatic struct { // entries contains the users, passwords and user data. entries map[string][]*AuthServerStaticEntry + // Signal handling related fields. sigChan chan os.Signal ticker *time.Ticker + done chan struct{} // Tell the signal related goroutines to stop } // AuthServerStaticEntry stores the values for a given user. @@ -267,11 +269,17 @@ func (a *AuthServerStatic) installSignalHandlers() { return } + a.done = make(chan struct{}) a.sigChan = make(chan os.Signal, 1) signal.Notify(a.sigChan, syscall.SIGHUP) go func() { - for range a.sigChan { - a.reload() + for { + select { + case <-a.done: + return + case <-a.sigChan: + a.reload() + } } }() @@ -279,14 +287,22 @@ func (a *AuthServerStatic) installSignalHandlers() { if a.reloadInterval > 0 { a.ticker = time.NewTicker(a.reloadInterval) go func() { - for range a.ticker.C { - a.sigChan <- syscall.SIGHUP + for { + select { + case <-a.done: + return + case <-a.ticker.C: + a.sigChan <- syscall.SIGHUP + } } }() } } func (a *AuthServerStatic) close() { + if a.done != nil { + close(a.done) + } if a.ticker != nil { a.ticker.Stop() } diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index a5957cccfed..20b85817988 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -137,6 +137,7 @@ func TestStaticConfigHUP(t *testing.T) { } func TestStaticConfigHUPWithRotation(t *testing.T) { + _ = utils.LeakCheckContext(t) tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") require.NoError(t, err, "couldn't create temp file: %v", err) defer os.Remove(tmpFile.Name()) From b306d8c5e5b7720525df9e33d0c7c3fd33677271 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Mon, 21 Oct 2024 16:43:48 -0400 Subject: [PATCH 4/5] Minor change from self review Signed-off-by: Matt Lord --- go/mysql/auth_server_clientcert_test.go | 5 ++--- go/mysql/replication_test.go | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 201c4d8dabd..72a1ecce87c 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -34,11 +34,10 @@ import ( const clientCertUsername = "Client Cert" -// The listener's accept loop actually only ends on a connection +// The listener's Accept() loop actually only ends on a connection // error, which will occur when trying to connect after the listener // has been closed. So this function closes the listener and then -// calls Connect to trigger the error which ends that conneciton -// request handler goroutine. +// calls Connect to trigger the error which ends that work. var cleanupListener = func(ctx context.Context, l *Listener, params *ConnParams) { l.Close() _, _ = Connect(ctx, params) diff --git a/go/mysql/replication_test.go b/go/mysql/replication_test.go index a5310a03764..933c40a5349 100644 --- a/go/mysql/replication_test.go +++ b/go/mysql/replication_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/test/utils" + binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" ) From 1084086bca2d206eb6f7afcab3add10d68e11be4 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Mon, 21 Oct 2024 23:01:38 -0400 Subject: [PATCH 5/5] Add the leak checker in all of the tests Signed-off-by: Matt Lord --- go/mysql/auth_server_static_test.go | 4 ++++ go/mysql/replication_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/go/mysql/auth_server_static_test.go b/go/mysql/auth_server_static_test.go index 20b85817988..a808ce9b66b 100644 --- a/go/mysql/auth_server_static_test.go +++ b/go/mysql/auth_server_static_test.go @@ -37,6 +37,7 @@ func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry { } func TestJsonConfigParser(t *testing.T) { + _ = utils.LeakCheckContext(t) // works with legacy format config := make(map[string][]*AuthServerStaticEntry) jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}" @@ -69,6 +70,7 @@ func TestJsonConfigParser(t *testing.T) { } func TestValidateHashGetter(t *testing.T) { + _ = utils.LeakCheckContext(t) jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}` auth := NewAuthServerStatic("", jsonConfig, 0) @@ -92,6 +94,7 @@ func TestValidateHashGetter(t *testing.T) { } func TestHostMatcher(t *testing.T) { + _ = utils.LeakCheckContext(t) ip := net.ParseIP("192.168.0.1") addr := &net.TCPAddr{IP: ip, Port: 9999} match := MatchSourceHost(net.Addr(addr), "") @@ -195,6 +198,7 @@ func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.Fi } func TestStaticPasswords(t *testing.T) { + _ = utils.LeakCheckContext(t) jsonConfig := ` { "user01": [{ "Password": "user01" }], diff --git a/go/mysql/replication_test.go b/go/mysql/replication_test.go index 933c40a5349..c9a54485497 100644 --- a/go/mysql/replication_test.go +++ b/go/mysql/replication_test.go @@ -29,6 +29,7 @@ import ( ) func TestComBinlogDump(t *testing.T) { + _ = utils.LeakCheckContext(t) listener, sConn, cConn := createSocketPair(t) defer func() { listener.Close()