Skip to content

Commit

Permalink
mysql: Refactor out usage of servenv
Browse files Browse the repository at this point in the history
The servenv package is something that is use for server environments
like vtgate, vttablet etc. But the go/mysql package should really be
independent code for things like the MySQL protocol bits.

This change refactors things so that go/mysql doesn't depend anymore on
servenv. This makes it easier to be used as a library for example.

The remaining bits here are in collations, which is something to tackle
separately as it's very invasive.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Dec 8, 2023
1 parent 87b047b commit 1e8b3b3
Show file tree
Hide file tree
Showing 23 changed files with 157 additions and 148 deletions.
6 changes: 5 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
"vitess.io/vitess/go/vt/vtgate"
)

var clientcertAuthMethod string

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert() })
Main.Flags().StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")

vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerClientCert(clientcertAuthMethod) })
}
13 changes: 12 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ package cli
// This plugin imports ldapauthserver to register the LDAP implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/ldapauthserver"
"vitess.io/vitess/go/vt/vtgate"
)

var (
ldapAuthConfigFile string
ldapAuthConfigString string
ldapAuthMethod string
)

func init() {
vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init() })
Main.Flags().StringVar(&ldapAuthConfigFile, "mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.")
Main.Flags().StringVar(&ldapAuthConfigString, "mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.")
Main.Flags().StringVar(&ldapAuthMethod, "mysql_ldap_auth_method", string(mysql.MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")

vtgate.RegisterPluginInitializer(func() { ldapauthserver.Init(ldapAuthConfigFile, ldapAuthConfigString, ldapAuthMethod) })
}
16 changes: 15 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@ package cli
// This plugin imports staticauthserver to register the flat-file implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vtgate"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerStatic() })
Main.Flags().StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
Main.Flags().StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
Main.Flags().DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")

vtgate.RegisterPluginInitializer(func() {
mysql.InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString, mysqlAuthServerStaticReloadInterval)
})
}
28 changes: 27 additions & 1 deletion go/cmd/vtgate/cli/plugin_auth_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,36 @@ package cli
// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"time"

"vitess.io/vitess/go/mysql/vault"
"vitess.io/vitess/go/vt/vtgate"
)

var (
vaultAddr string
vaultTimeout time.Duration
vaultCACert string
vaultPath string
vaultCacheTTL time.Duration
vaultTokenFile string
vaultRoleID string
vaultRoleSecretIDFile string
vaultRoleMountPoint string
)

func init() {
vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() })
Main.Flags().StringVar(&vaultAddr, "mysql_auth_vault_addr", "", "URL to Vault server")
Main.Flags().DurationVar(&vaultTimeout, "mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations")
Main.Flags().StringVar(&vaultCACert, "mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate")
Main.Flags().StringVar(&vaultPath, "mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds")
Main.Flags().DurationVar(&vaultCacheTTL, "mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server")
Main.Flags().StringVar(&vaultTokenFile, "mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable")
Main.Flags().StringVar(&vaultRoleID, "mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable")
Main.Flags().StringVar(&vaultRoleSecretIDFile, "mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable")
Main.Flags().StringVar(&vaultRoleMountPoint, "mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable")

vtgate.RegisterPluginInitializer(func() {
vault.InitAuthServerVault(vaultAddr, vaultTimeout, vaultCACert, vaultPath, vaultCacheTTL, vaultTokenFile, vaultRoleID, vaultRoleSecretIDFile, vaultRoleMountPoint)
})
}
1 change: 1 addition & 0 deletions go/flags/endtoend/vtcombo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ Flags:
--mysql_default_workload string Default session workload (OLTP, OLAP, DBA) (default "OLTP")
--mysql_port int mysql port (default 3306)
--mysql_server_bind_address string Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.
--mysql_server_flush_delay duration Delay after which buffered response will be flushed to the client. (default 100ms)
--mysql_server_port int If set, also listen for MySQL binary protocol connections on this port. (default -1)
--mysql_server_query_timeout duration mysql query timeout
--mysql_server_read_timeout duration connection read timeout
Expand Down
15 changes: 3 additions & 12 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,16 @@ import (
"github.com/spf13/pflag"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
)

var clientcertAuthMethod string

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&clientcertAuthMethod, "mysql_clientcert_auth_method", string(MysqlClearPassword), "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
})
}

// AuthServerClientCert implements AuthServer which enforces client side certificates
type AuthServerClientCert struct {
methods []AuthMethod
Method AuthMethodDescription
}

// InitAuthServerClientCert is public so it can be called from plugin_auth_clientcert.go (go/cmd/vtgate)
func InitAuthServerClientCert() {
func InitAuthServerClientCert(clientcertAuthMethod string) {
if pflag.CommandLine.Lookup("mysql_server_ssl_ca").Value.String() == "" {
log.Info("Not configuring AuthServerClientCert because mysql_server_ssl_ca is empty")
return
Expand All @@ -50,11 +41,11 @@ func InitAuthServerClientCert() {
log.Exitf("Invalid mysql_clientcert_auth_method value: only support mysql_clear_password or dialog")
}

ascc := newAuthServerClientCert()
ascc := newAuthServerClientCert(clientcertAuthMethod)
RegisterAuthServer("clientcert", ascc)
}

func newAuthServerClientCert() *AuthServerClientCert {
func newAuthServerClientCert(clientcertAuthMethod string) *AuthServerClientCert {
ascc := &AuthServerClientCert{
Method: AuthMethodDescription(clientcertAuthMethod),
}
Expand Down
14 changes: 4 additions & 10 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ import (

const clientCertUsername = "Client Cert"

func init() {
// These tests do not invoke the servenv.Parse codepaths, so this default
// does not get set by the OnParseFor hook.
clientcertAuthMethod = string(MysqlClearPassword)
}

func TestValidCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -111,10 +105,10 @@ func TestValidCert(t *testing.T) {
func TestNoCert(t *testing.T) {
th := &testHandler{}

authServer := newAuthServerClientCert()
authServer := newAuthServerClientCert(string(MysqlClearPassword))

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
21 changes: 1 addition & 20 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,15 @@ import (
"syscall"
"time"

"github.com/spf13/pflag"

"vitess.io/vitess/go/mysql/sqlerror"

"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
)

var (
mysqlAuthServerStaticFile string
mysqlAuthServerStaticString string
mysqlAuthServerStaticReloadInterval time.Duration
mysqlServerFlushDelay = 100 * time.Millisecond
)

func init() {
servenv.OnParseFor("vtgate", func(fs *pflag.FlagSet) {
fs.StringVar(&mysqlAuthServerStaticFile, "mysql_auth_server_static_file", "", "JSON File to read the users/passwords from.")
fs.StringVar(&mysqlAuthServerStaticString, "mysql_auth_server_static_string", "", "JSON representation of the users/passwords config.")
fs.DurationVar(&mysqlAuthServerStaticReloadInterval, "mysql_auth_static_reload_interval", 0, "Ticker to reload credentials")
fs.DurationVar(&mysqlServerFlushDelay, "mysql_server_flush_delay", mysqlServerFlushDelay, "Delay after which buffered response will be flushed to the client.")
})
}

const (
localhostName = "localhost"
)
Expand Down Expand Up @@ -94,7 +75,7 @@ type AuthServerStaticEntry struct {
}

// InitAuthServerStatic Handles initializing the AuthServerStatic if necessary.
func InitAuthServerStatic() {
func InitAuthServerStatic(mysqlAuthServerStaticFile, mysqlAuthServerStaticString string, mysqlAuthServerStaticReloadInterval time.Duration) {
// Check parameters.
if mysqlAuthServerStaticFile == "" && mysqlAuthServerStaticString == "" {
// Not configured, nothing to do.
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
}

// Send the connection back, so the other side can close it.
c := newConn(conn)
c := newConn(conn, params.FlushDelay)
status <- connectResult{
c: c,
}
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -296,7 +296,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -343,7 +343,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -426,7 +426,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0, mysqlVersion)
require.NoError(t, err)
defer l.Close()

Expand Down
14 changes: 11 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import (
)

const (
DefaultFlushDelay = 100 * time.Millisecond

// connBufferSize is how much we buffer for reading and
// writing. It is also how much we allocate for ephemeral buffers.
connBufferSize = 16 * 1024
Expand Down Expand Up @@ -129,6 +131,7 @@ type Conn struct {

bufferedReader *bufio.Reader
flushTimer *time.Timer
flushDelay time.Duration
header [packetHeaderSize]byte

// Keep track of how and of the buffer we allocated for an
Expand Down Expand Up @@ -247,10 +250,14 @@ var readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, co

// newConn is an internal method to create a Conn. Used by client and server
// side for common creation code.
func newConn(conn net.Conn) *Conn {
func newConn(conn net.Conn, flushDelay time.Duration) *Conn {
if flushDelay == 0 {
flushDelay = DefaultFlushDelay
}
return &Conn{
conn: conn,
bufferedReader: bufio.NewReaderSize(conn, connBufferSize),
flushDelay: flushDelay,
}
}

Expand All @@ -275,6 +282,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn {
listener: listener,
PrepareData: make(map[uint32]*PrepareData),
keepAliveOn: enabledKeepAlive,
flushDelay: listener.flushDelay,
}

if listener.connReadBufferSize > 0 {
Expand Down Expand Up @@ -348,7 +356,7 @@ func (c *Conn) returnReader() {
// startFlushTimer must be called while holding lock on bufMu.
func (c *Conn) startFlushTimer() {
if c.flushTimer == nil {
c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() {
c.flushTimer = time.AfterFunc(c.flushDelay, func() {
c.bufMu.Lock()
defer c.bufMu.Unlock()

Expand All @@ -358,7 +366,7 @@ func (c *Conn) startFlushTimer() {
c.bufferedWriter.Flush()
})
} else {
c.flushTimer.Reset(mysqlServerFlushDelay)
c.flushTimer.Reset(c.flushDelay)
}
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/conn_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ var _ net.Addr = (*mockAddress)(nil)

// GetTestConn returns a conn for testing purpose only.
func GetTestConn() *Conn {
return newConn(testConn{})
return newConn(testConn{}, DefaultFlushDelay)
}

// GetTestServerConn is only meant to be used for testing.
Expand Down
12 changes: 6 additions & 6 deletions go/mysql/conn_flaky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
require.Nil(t, serverErr, "Accept failed: %v", serverErr)

// Create a Conn on both sides.
cConn := newConn(clientConn)
sConn := newConn(serverConn)
cConn := newConn(clientConn, DefaultFlushDelay)
sConn := newConn(serverConn, DefaultFlushDelay)
sConn.PrepareData = map[uint32]*PrepareData{}

return listener, sConn, cConn
Expand Down Expand Up @@ -942,7 +942,7 @@ func TestConnectionErrorWhileWritingComQuery(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComQuery, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
errorString := make([]byte, 17000)
Expand All @@ -958,7 +958,7 @@ func TestConnectionErrorWhileWritingComStmtSendLongData(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtSendLongData, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)

// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -972,7 +972,7 @@ func TestConnectionErrorWhileWritingComPrepare(t *testing.T) {
writeToPass: []bool{false},
pos: -1,
queryPacket: []byte{0x01, 0x00, 0x00, 0x00, ComPrepare},
})
}, DefaultFlushDelay)
sConn.Capabilities = sConn.Capabilities | CapabilityClientMultiStatements
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
Expand All @@ -987,7 +987,7 @@ func TestConnectionErrorWhileWritingComStmtExecute(t *testing.T) {
pos: -1,
queryPacket: []byte{0x21, 0x00, 0x00, 0x00, ComStmtExecute, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31},
})
}, DefaultFlushDelay)
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &testRun{t: t, err: fmt.Errorf("not used")}
res := sConn.handleNextCommand(handler)
Expand Down
Loading

0 comments on commit 1e8b3b3

Please sign in to comment.