Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flakes: Setup new fake server if it has gone away #17023

Merged
merged 5 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"
"vitess.io/vitess/go/vt/tlstest"
"vitess.io/vitess/go/vt/vttls"
)

const clientCertUsername = "Client Cert"

// The listener's Accept() loop actually only ends on a connection
// error, which will occur when trying to connect after the listener
// has been closed. So this function closes the listener and then
// calls Connect to trigger the error which ends that work.
var cleanupListener = func(ctx context.Context, l *Listener, params *ConnParams) {
l.Close()
_, _ = Connect(ctx, params)
}

func TestValidCert(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := newAuthServerClientCert(string(MysqlClearPassword))
Expand All @@ -52,21 +63,6 @@ func TestValidCert(t *testing.T) {
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername)
tlstest.CreateCRL(root, tlstest.CA)

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -81,7 +77,20 @@ func TestValidCert(t *testing.T) {
ServerName: "server.example.com",
}

ctx := context.Background()
// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go l.Accept()
defer cleanupListener(ctx, l, params)

conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand All @@ -103,6 +112,7 @@ func TestValidCert(t *testing.T) {
}

func TestNoCert(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := newAuthServerClientCert(string(MysqlClearPassword))
Expand All @@ -120,6 +130,17 @@ func TestNoCert(t *testing.T) {
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateCRL(root, tlstest.CA)

// Setup the right parameters.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "",
SslMode: vttls.VerifyIdentity,
SslCa: path.Join(root, "ca-cert.pem"),
ServerName: "server.example.com",
}

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
Expand All @@ -131,22 +152,9 @@ func TestNoCert(t *testing.T) {
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "",
SslMode: vttls.VerifyIdentity,
SslCa: path.Join(root, "ca-cert.pem"),
ServerName: "server.example.com",
}
go l.Accept()
defer cleanupListener(ctx, l, params)

ctx := context.Background()
conn, err := Connect(ctx, params)
assert.Error(t, err, "Connect() should have errored due to no client cert")

Expand Down
24 changes: 20 additions & 4 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ type AuthServerStatic struct {
// entries contains the users, passwords and user data.
entries map[string][]*AuthServerStaticEntry

// Signal handling related fields.
sigChan chan os.Signal
ticker *time.Ticker
done chan struct{} // Tell the signal related goroutines to stop
}

// AuthServerStaticEntry stores the values for a given user.
Expand Down Expand Up @@ -267,26 +269,40 @@ func (a *AuthServerStatic) installSignalHandlers() {
return
}

a.done = make(chan struct{})
a.sigChan = make(chan os.Signal, 1)
signal.Notify(a.sigChan, syscall.SIGHUP)
go func() {
for range a.sigChan {
a.reload()
for {
select {
case <-a.done:
return
case <-a.sigChan:
a.reload()
}
}
}()

// If duration is set, it will reload configuration every interval
if a.reloadInterval > 0 {
a.ticker = time.NewTicker(a.reloadInterval)
go func() {
for range a.ticker.C {
a.sigChan <- syscall.SIGHUP
for {
select {
case <-a.done:
return
case <-a.ticker.C:
a.sigChan <- syscall.SIGHUP
}
}
}()
}
}

func (a *AuthServerStatic) close() {
if a.done != nil {
close(a.done)
}
if a.ticker != nil {
a.ticker.Stop()
}
Expand Down
29 changes: 25 additions & 4 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"
)

// getEntries is a test-only method for AuthServerStatic.
Expand All @@ -35,6 +37,7 @@ func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry {
}

func TestJsonConfigParser(t *testing.T) {
_ = utils.LeakCheckContext(t)
// works with legacy format
config := make(map[string][]*AuthServerStaticEntry)
jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}"
Expand Down Expand Up @@ -67,6 +70,7 @@ func TestJsonConfigParser(t *testing.T) {
}

func TestValidateHashGetter(t *testing.T) {
_ = utils.LeakCheckContext(t)
jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}`

auth := NewAuthServerStatic("", jsonConfig, 0)
Expand All @@ -90,6 +94,7 @@ func TestValidateHashGetter(t *testing.T) {
}

func TestHostMatcher(t *testing.T) {
_ = utils.LeakCheckContext(t)
ip := net.ParseIP("192.168.0.1")
addr := &net.TCPAddr{IP: ip, Port: 9999}
match := MatchSourceHost(net.Addr(addr), "")
Expand All @@ -105,9 +110,9 @@ func TestHostMatcher(t *testing.T) {
}

func TestStaticConfigHUP(t *testing.T) {
_ = utils.LeakCheckContext(t)
tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
require.NoError(t, err, "couldn't create temp file: %v", err)

defer os.Remove(tmpFile.Name())

oldStr := "str5"
Expand All @@ -125,14 +130,19 @@ func TestStaticConfigHUP(t *testing.T) {

mu.Lock()
defer mu.Unlock()
// delete registered Auth server
clear(authServers)
// Delete registered Auth servers.
for k, v := range authServers {
if s, ok := v.(*AuthServerStatic); ok {
s.close()
}
delete(authServers, k)
}
}

func TestStaticConfigHUPWithRotation(t *testing.T) {
_ = utils.LeakCheckContext(t)
tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
require.NoError(t, err, "couldn't create temp file: %v", err)

defer os.Remove(tmpFile.Name())

oldStr := "str1"
Expand All @@ -147,6 +157,16 @@ func TestStaticConfigHUPWithRotation(t *testing.T) {

hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4")
hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5")

mu.Lock()
defer mu.Unlock()
// Delete registered Auth servers.
for k, v := range authServers {
if s, ok := v.(*AuthServerStatic); ok {
s.close()
}
delete(authServers, k)
}
}

func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
Expand Down Expand Up @@ -178,6 +198,7 @@ func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.Fi
}

func TestStaticPasswords(t *testing.T) {
_ = utils.LeakCheckContext(t)
jsonConfig := `
{
"user01": [{ "Password": "user01" }],
Expand Down
29 changes: 11 additions & 18 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
// This file tests the handshake scenarios between our client and our server.

func TestClearTextClientAuth(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
Expand All @@ -51,10 +52,6 @@ func TestClearTextClientAuth(t *testing.T) {
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
port := l.Addr().(*net.TCPAddr).Port
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -63,9 +60,10 @@ func TestClearTextClientAuth(t *testing.T) {
Pass: "password1",
SslMode: vttls.Disabled,
}
go l.Accept()
defer cleanupListener(ctx, l, params)

// Connection should fail, as server requires SSL for clear text auth.
ctx := context.Background()
_, err = Connect(ctx, params)
if err == nil || !strings.Contains(err.Error(), "Cannot use clear text authentication over non-SSL connections") {
t.Fatalf("unexpected connection error: %v", err)
Expand All @@ -92,6 +90,7 @@ func TestClearTextClientAuth(t *testing.T) {
// TestSSLConnection creates a server with TLS support, a client that
// also has SSL support, and connects them.
func TestSSLConnection(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
Expand All @@ -103,7 +102,6 @@ func TestSSLConnection(t *testing.T) {
// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
port := l.Addr().(*net.TCPAddr).Port

Expand All @@ -122,12 +120,6 @@ func TestSSLConnection(t *testing.T) {
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -141,20 +133,22 @@ func TestSSLConnection(t *testing.T) {
SslKey: path.Join(root, "client-key.pem"),
ServerName: "server.example.com",
}
l.TLSConfig.Store(serverConfig)
go l.Accept()
defer cleanupListener(ctx, l, params)

t.Run("Basics", func(t *testing.T) {
testSSLConnectionBasics(t, params)
testSSLConnectionBasics(t, ctx, params)
})

// Make sure clear text auth works over SSL.
t.Run("ClearText", func(t *testing.T) {
testSSLConnectionClearText(t, params)
testSSLConnectionClearText(t, ctx, params)
})
}

func testSSLConnectionClearText(t *testing.T, params *ConnParams) {
func testSSLConnectionClearText(t *testing.T, ctx context.Context, params *ConnParams) {
// Create a client connection, connect.
ctx := context.Background()
conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand All @@ -170,9 +164,8 @@ func testSSLConnectionClearText(t *testing.T, params *ConnParams) {
conn.writeComQuit()
}

func testSSLConnectionBasics(t *testing.T, params *ConnParams) {
func testSSLConnectionBasics(t *testing.T, ctx context.Context, params *ConnParams) {
// Create a client connection, connect.
ctx := context.Background()
conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand Down
5 changes: 5 additions & 0 deletions go/mysql/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"

binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
)

func TestComBinlogDump(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down Expand Up @@ -72,6 +75,7 @@ func TestComBinlogDump(t *testing.T) {
}

func TestComBinlogDumpGTID(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down Expand Up @@ -161,6 +165,7 @@ func TestComBinlogDumpGTID(t *testing.T) {
}

func TestSendSemiSyncAck(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down
Loading
Loading