Skip to content

Commit

Permalink
Merge pull request #131236 from cockroachdb/blathers/backport-release…
Browse files Browse the repository at this point in the history
…-24.2.3-rc-131221

release-24.2.3-rc: util: don't panic on IPv6 entries in cidr mapping
  • Loading branch information
andrewbaptist authored Sep 23, 2024
2 parents 0963c2b + 611bce8 commit b4abaa2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 69 deletions.
46 changes: 8 additions & 38 deletions pkg/util/cidr/cidr.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ package cidr

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
io "io"
Expand Down Expand Up @@ -283,7 +282,7 @@ func (c *Lookup) setDestinations(ctx context.Context, contents []byte) error {
if err := json.Unmarshal(contents, &destinations); err != nil {
return err
}
// TODO(baptist): This only handles IPv4. We could change to 128 if we want
// TODO(#130814): This only handles IPv4. We could change to 128 if we want
// to handle IPv6.
byLength := make([]map[string]string, 33)
for i := range 33 {
Expand All @@ -295,6 +294,9 @@ func (c *Lookup) setDestinations(ctx context.Context, contents []byte) error {
return err
}
lenBits, _ := cidr.Mask.Size()
if lenBits > 32 {
return fmt.Errorf("invalid mask size: %d", lenBits)
}
mask := net.CIDRMask(lenBits, 32)
val := hexString(cidr.IP.Mask(mask))
byLength[lenBits][val] = d.Name
Expand Down Expand Up @@ -334,6 +336,10 @@ func (c *Lookup) onChange(ctx context.Context) {
func (c *Lookup) LookupIP(ip net.IP) string {
byLength := *c.byLength.Load()
ip = ip.To4()
// Don't map IPv6 addresses.
if ip == nil {
return ""
}
for i := len(byLength) - 1; i >= 0; i-- {
m := (byLength)[i]
if len(m) == 0 {
Expand Down Expand Up @@ -400,42 +406,6 @@ func (m *NetMetrics) Wrap(dial DialContext, labels ...string) DialContext {
}
}

// WrapTLS is like Wrap, but can be used if the underlying library doesn't
// expose a way to plug in a dialer for TLS connections. This is unfortunately
// pretty ugly... Copied from tls.Dial and kgo.DialTLS because they don't expose
// a dial call with a DialContext. Ideally you don't have to use this if the
// third party API does a sensible thing and exposes the ability to replace the
// "DialContext" directly.
func (m *NetMetrics) WrapTLS(dial DialContext, tlsCfg *tls.Config, labels ...string) DialContext {
return func(ctx context.Context, network, host string) (net.Conn, error) {
c := tlsCfg.Clone()
if c.ServerName == "" {
server, _, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("unable to split host:port for dialing: %w", err)
}
c.ServerName = server
}

rawConn, err := dial(ctx, network, host)
if err != nil {
return nil, err
}
scopedConn := rawConn
// m can be nil in tests.
if m != nil {
scopedConn = m.track(rawConn, labels...)
}

conn := tls.Client(scopedConn, c)
if err := conn.HandshakeContext(ctx); err != nil {
scopedConn.Close()
return nil, err
}
return conn, nil
}
}

type Dialer interface {
Dial(network, addr string) (c net.Conn, err error)
DialContext(ctx context.Context, network, addr string) (c net.Conn, err error)
Expand Down
33 changes: 2 additions & 31 deletions pkg/util/cidr/cidr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ package cidr

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -54,6 +53,7 @@ func TestCIDRLookup(t *testing.T) {
{"10.0.0.2", "CIDR3"},
{"10.0.0.1", "CIDR4"},
{"172.16.0.1", ""},
{"2001:0db8:0a0b:12f0:0000:0000:0000:0001", ""},
}
for _, tc := range testCases {
t.Run(tc.ip, func(t *testing.T) {
Expand Down Expand Up @@ -92,6 +92,7 @@ func TestInvalidCIDR(t *testing.T) {
{"int name ", `[ { "Name": 1, "Ipnet": "192.168.0.0/24" } ]`},
{"missing cidr", `[ { Name: "CIDR1" } ]`},
{"malformed cidr", `[ { "Name": "CIDR1", "Ipnet": "192.168.0.0.1/24" } ]`},
{"ipv6", `[ { "Name": "CIDR1", "Ipnet": "2001:db8::/40" } ]`},
}
c := Lookup{}
for _, tc := range testCases {
Expand Down Expand Up @@ -218,36 +219,6 @@ func TestWrapHTTP(t *testing.T) {
require.Greater(t, m.mu.childMetrics["foo/test"].ReadBytes.Value(), int64(1))
}

// TestWrapHTTP validates the wrapping function for HTTP connections.
func TestWrapHTTPS(t *testing.T) {
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer s.Close()
// Create a mapping for this server's IP.
mapping := fmt.Sprintf(`[ { "Name": "test", "Ipnet": "%s/32" } ]`, s.Listener.Addr().(*net.TCPAddr).IP.String())
c := Lookup{}
require.NoError(t, c.setDestinations(context.Background(), []byte(mapping)))

// This is the standard way to wrap the transport.
m := c.MakeNetMetrics(writeBytes, readBytes, "label")
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DialTLSContext = m.WrapTLS(transport.DialContext, &tls.Config{InsecureSkipVerify: true}, "foo")

// Create a simple get request.
client := &http.Client{Transport: transport}
_, err := client.Get(s.URL)
require.NoError(t, err)

// Ideally we could check the actual value, but the header includes the date
// and could be flaky.
require.Greater(t, m.WriteBytes.Count(), int64(1))
require.Greater(t, m.ReadBytes.Count(), int64(1))
// Also check the child metrics by looking up in the map directly.
m.mu.Lock()
defer m.mu.Unlock()
require.Greater(t, m.mu.childMetrics["foo/test"].WriteBytes.Value(), int64(1))
require.Greater(t, m.mu.childMetrics["foo/test"].ReadBytes.Value(), int64(1))
}

func TestWrapDialer(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer s.Close()
Expand Down

0 comments on commit b4abaa2

Please sign in to comment.