From 9bb7ed0fe8cc285d083421bbb0e1ac2f77a1ae30 Mon Sep 17 00:00:00 2001 From: oleiade Date: Wed, 1 Nov 2023 11:59:58 +0100 Subject: [PATCH] Ensure the redis TLS connection uses k6's netext.Dialer under the hood --- redis/client.go | 99 ++++++++++++++++++++++++++++++-------------- redis/client_test.go | 51 +++++++++++++++++++++++ 2 files changed, 118 insertions(+), 32 deletions(-) diff --git a/redis/client.go b/redis/client.go index b22848d..ebc1d68 100644 --- a/redis/client.go +++ b/redis/client.go @@ -1,15 +1,17 @@ package redis import ( + "context" "crypto/tls" "fmt" + "net" "time" "github.com/dop251/goja" "github.com/redis/go-redis/v9" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" - "go.k6.io/k6/lib/netext" + "go.k6.io/k6/lib" ) // Client represents the Client constructor (i.e. `new redis.Client()`) and @@ -1080,37 +1082,34 @@ func (c *Client) connect() error { } tlsCfg := c.redisOptions.TLSConfig - if tlsCfg != nil { - if vuState.TLSConfig != nil { - // Merge k6 TLS configuration with the one we received from the - // Client constructor. This will need adjusting depending on which - // options we want to expose in the Redis module, and how we want - // the override to work. - tlsCfg.InsecureSkipVerify = vuState.TLSConfig.InsecureSkipVerify - tlsCfg.CipherSuites = vuState.TLSConfig.CipherSuites - tlsCfg.MinVersion = vuState.TLSConfig.MinVersion - tlsCfg.MaxVersion = vuState.TLSConfig.MaxVersion - tlsCfg.Renegotiation = vuState.TLSConfig.Renegotiation - tlsCfg.KeyLogWriter = vuState.TLSConfig.KeyLogWriter - - tlsCfg.Certificates = append(tlsCfg.Certificates, vuState.TLSConfig.Certificates...) - - // TODO: Merge vuState.TLSConfig.RootCAs with - // c.redisOptions.TLSConfig. k6 currently doesn't allow setting - // this, so it doesn't matter right now, but these should be merged. - // I couldn't find a way to do this with the x509.CertPool API - // though... - } - - k6dialer, ok := vuState.Dialer.(*netext.Dialer) - if !ok { - panic(fmt.Sprintf("expected *netext.Dialer, got: %T", vuState.Dialer)) - } - tlsDialer := &tls.Dialer{ - NetDialer: &k6dialer.Dialer, - Config: tlsCfg, - } - c.redisOptions.Dialer = tlsDialer.DialContext + if tlsCfg != nil && vuState.TLSConfig != nil { + // Merge k6 TLS configuration with the one we received from the + // Client constructor. This will need adjusting depending on which + // options we want to expose in the Redis module, and how we want + // the override to work. + tlsCfg.InsecureSkipVerify = vuState.TLSConfig.InsecureSkipVerify + tlsCfg.CipherSuites = vuState.TLSConfig.CipherSuites + tlsCfg.MinVersion = vuState.TLSConfig.MinVersion + tlsCfg.MaxVersion = vuState.TLSConfig.MaxVersion + tlsCfg.Renegotiation = vuState.TLSConfig.Renegotiation + tlsCfg.KeyLogWriter = vuState.TLSConfig.KeyLogWriter + tlsCfg.Certificates = append(tlsCfg.Certificates, vuState.TLSConfig.Certificates...) + + // TODO: Merge vuState.TLSConfig.RootCAs with + // c.redisOptions.TLSConfig. k6 currently doesn't allow setting + // this, so it doesn't matter right now, but these should be merged. + // I couldn't find a way to do this with the x509.CertPool API + // though... + + // In order to preserve the underlying effects of the [netext.Dialer], such + // as handling blocked hostnames, or handling hostname resolution, we override + // the redis client's dialer with our own function which uses the VU's [netext.Dialer] + // and manually upgrades the connection to TLS. + // + // See Pull Request's #17 [discussion] for more details. + // + // [discussion]: https://github.com/grafana/xk6-redis/pull/17#discussion_r1369707388 + c.redisOptions.Dialer = c.upgradeDialerToTLS(vuState.Dialer, tlsCfg) } else { c.redisOptions.Dialer = vuState.Dialer.DialContext } @@ -1154,3 +1153,39 @@ func (c *Client) isSupportedType(offset int, args ...interface{}) error { return nil } + +// DialContextFunc is a function that can be used to dial a connection to a redis server. +type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// upgradeDialerToTLS returns a DialContextFunc that uses the provided dialer to +// establish a connection, and then upgrades it to TLS using the provided config. +// +// We use this function to make sure the k6 [netext.Dialer], our redis module uses to establish +// the connection and handle network-related options such as blocked hostnames, +// or hostname resolution, but we also want to use the TLS configuration provided +// by the user. +func (c *Client) upgradeDialerToTLS(dialer lib.DialContexter, config *tls.Config) DialContextFunc { + return func(ctx context.Context, network string, addr string) (net.Conn, error) { + // Use netext.Dialer to establish the connection + rawConn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + // Upgrade the connection to TLS if needed + tlsConn := tls.Client(rawConn, config) + err = tlsConn.Handshake() + if err != nil { + if closeErr := rawConn.Close(); closeErr != nil { + return nil, fmt.Errorf("failed to close connection after TLS handshake error: %w", closeErr) + } + + return nil, err + } + + // Overwrite rawConn with the TLS connection + rawConn = tlsConn + + return rawConn, nil + } +} diff --git a/redis/client_test.go b/redis/client_test.go index 0e6d312..de63fc2 100644 --- a/redis/client_test.go +++ b/redis/client_test.go @@ -2560,6 +2560,7 @@ type testSetup struct { state *lib.State samples chan metrics.SampleContainer ev *eventloop.EventLoop + tb *httpmultibin.HTTPMultiBin } // newTestSetup initializes a new test setup. @@ -2618,6 +2619,7 @@ func newTestSetup(t testing.TB) testSetup { state: state, samples: samples, ev: ev, + tb: tb, } } @@ -2732,3 +2734,52 @@ func TestClientTLSAuth(t *testing.T) { {"PING"}, }, rs.GotCommands()) } + +func TestClientTLSRespectsNetworkOPtions(t *testing.T) { + t.Parallel() + + clientCert, clientPKey, err := generateTLSCert() + require.NoError(t, err) + + ts := newTestSetup(t) + rs := RunTSecure(t, clientCert) + + err = ts.rt.Set("caCert", string(rs.TLSCertificate())) + require.NoError(t, err) + err = ts.rt.Set("clientCert", string(clientCert)) + require.NoError(t, err) + err = ts.rt.Set("clientPKey", string(clientPKey)) + require.NoError(t, err) + + // Set the redis server's IP to be blacklisted. + net, err := lib.ParseCIDR(rs.Addr().IP.String() + "/32") + require.NoError(t, err) + ts.tb.Dialer.Blacklist = []*lib.IPNet{net} + + gotScriptErr := ts.ev.Start(func() error { + _, err := ts.rt.RunString(fmt.Sprintf(` + const redis = new Client({ + socket: { + host: '%s', + port: %d, + tls: { + ca: [caCert], + cert: clientCert, + key: clientPKey + } + } + }); + + // This operation triggers a connection to the redis + // server under the hood, and should therefore fail, since + // the server's IP is blacklisted by k6. + redis.sendCommand("PING") + `, rs.Addr().IP.String(), rs.Addr().Port)) + + return err + }) + + assert.Error(t, gotScriptErr) + assert.ErrorContains(t, gotScriptErr, "IP ("+rs.Addr().IP.String()+") is in a blacklisted range") + assert.Equal(t, 0, rs.HandledCommandsCount()) +}