Skip to content

Commit

Permalink
Properly support mTLS when scraping metrics
Browse files Browse the repository at this point in the history
Verify server signatures. Also move the timeout out of the HTTP client
into the Scrape method's context.

Signed-off-by: Tom Wieczorek <[email protected]>
  • Loading branch information
twz123 committed Aug 1, 2024
1 parent 0279b4b commit 5064d5e
Showing 1 changed file with 104 additions and 40 deletions.
144 changes: 104 additions & 40 deletions pkg/component/controller/metrics/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ package metrics
import (
"context"
"crypto/tls"
"crypto/x509"
_ "embed"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
Expand All @@ -31,6 +34,7 @@ import (
"time"

"github.com/k0sproject/k0s/internal/pkg/dir"
internalnet "github.com/k0sproject/k0s/internal/pkg/net"
"github.com/k0sproject/k0s/internal/pkg/templatewriter"
"github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/k0sproject/k0s/pkg/component/manager"
Expand Down Expand Up @@ -87,19 +91,27 @@ func NewComponent(k0sVars *config.CfgVars, clientCF kubernetes.ClientFactoryInte
}

// Init does nothing
func (c *Component) Init(_ context.Context) error {
func (c *Component) Init(ctx context.Context) error {
if err := dir.Init(filepath.Join(c.K0sVars.ManifestsDir, "metrics"), constant.ManifestsDirMode); err != nil {
return err
}

loopbackIP, err := internalnet.LookupLoopbackIP(ctx)
if err != nil {
if errors.Is(err, ctx.Err()) {
return err
}
c.log.WithError(err).Errorf("Falling back to %s as bind address", loopbackIP)
}

var j *job
j, err := c.newJob("https://localhost:10259/metrics")
j, err = c.newKubernetesJob(fmt.Sprintf("https://%s/metrics", net.JoinHostPort(loopbackIP.String(), "10259")))
if err != nil {
return err
}
c.jobs["kube-scheduler"] = j

j, err = c.newJob("https://localhost:10257/metrics")
j, err = c.newKubernetesJob(fmt.Sprintf("https://%s/metrics", net.JoinHostPort(loopbackIP.String(), "10257")))
if err != nil {
return err
}
Expand Down Expand Up @@ -168,50 +180,79 @@ func (c *Component) Reconcile(_ context.Context, clusterConfig *v1beta1.ClusterC
}

type job struct {
scrapeURL string
scrapeClient *http.Client
scrapeURL string
scrapeClient *http.Client
scrapeTimeout time.Duration
}

func (c *Component) newEtcdJob() (*job, error) {
certFile := path.Join(c.K0sVars.CertRootDir, "apiserver-etcd-client.crt")
keyFile := path.Join(c.K0sVars.CertRootDir, "apiserver-etcd-client.key")

httpClient, err := getClient(certFile, keyFile)
rootCAs, err := c.loadRootCAs()
if err != nil {
return nil, err
}
clientCerts, err := c.loadClientCerts("apiserver-etcd-client")
if err != nil {
return nil, err
}

return &job{
scrapeURL: "https://localhost:2379/metrics",
scrapeClient: httpClient,
scrapeURL: "https://localhost:2379/metrics",
scrapeClient: newHttpClient(&tls.Config{
RootCAs: rootCAs,
Certificates: clientCerts,
}),
scrapeTimeout: 1 * time.Minute,
}, nil
}

func (c *Component) newKineJob() (*job, error) {
httpClient, err := getClient("", "")
return &job{
scrapeURL: "http://localhost:2380/metrics",
scrapeClient: newHttpClient(nil),
scrapeTimeout: 1 * time.Minute,
}, nil
}

func (c *Component) newKubernetesJob(scrapeURL string) (*job, error) {
rootCAs, err := c.loadRootCAs()
if err != nil {
return nil, err
}
clientCerts, err := c.loadClientCerts("admin")
if err != nil {
return nil, err
}

return &job{
scrapeURL: "http://localhost:2380/metrics",
scrapeClient: httpClient,
scrapeURL: scrapeURL,
scrapeClient: newHttpClient(&tls.Config{
RootCAs: rootCAs,
Certificates: clientCerts,
}),
scrapeTimeout: 1 * time.Minute,
}, nil
}

func (c *Component) newJob(scrapeURL string) (*job, error) {
certFile := path.Join(c.K0sVars.CertRootDir, "admin.crt")
keyFile := path.Join(c.K0sVars.CertRootDir, "admin.key")
func (c *Component) loadRootCAs() (*x509.CertPool, error) {
rootCAs := x509.NewCertPool()
if rootCA, err := os.ReadFile(filepath.Join(c.K0sVars.CertRootDir, "ca.crt")); err != nil {
return nil, fmt.Errorf("failed to load root TLS certificates: %w", err)
} else if ok := rootCAs.AppendCertsFromPEM(rootCA); !ok {
return nil, fmt.Errorf("failed to append root TLS certificates to pool")
}

return rootCAs, nil
}

httpClient, err := getClient(certFile, keyFile)
func (c *Component) loadClientCerts(name string) ([]tls.Certificate, error) {
certFile := path.Join(c.K0sVars.CertRootDir, name+".crt")
keyFile := path.Join(c.K0sVars.CertRootDir, name+".key")
clientCert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to load key pair: %w", err)
}

return &job{
scrapeURL: scrapeURL,
scrapeClient: httpClient,
}, nil
return []tls.Certificate{clientCert}, nil
}

func (c *Component) run(ctx context.Context, jobName string, s Scraper) {
Expand All @@ -231,6 +272,9 @@ func (c *Component) run(ctx context.Context, jobName string, s Scraper) {
}

func (c *Component) collectAndPush(ctx context.Context, jobName string, s Scraper) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

metrics, err := s.Scrape(ctx)
if err != nil {
return err
Expand All @@ -248,7 +292,14 @@ func (c *Component) collectAndPush(ctx context.Context, jobName string, s Scrape
return nil
}

func (j *job) Scrape(ctx context.Context) (io.ReadCloser, error) {
func (j *job) Scrape(ctx context.Context) (_ io.ReadCloser, err error) {
ctx, cancel := context.WithTimeout(ctx, j.scrapeTimeout)
defer func() {
if err != nil {
cancel()
}
}()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.scrapeURL, nil)
if err != nil {
return nil, fmt.Errorf("error creating %s request for %s: %w", http.MethodGet, j.scrapeURL, err)
Expand All @@ -257,7 +308,7 @@ func (j *job) Scrape(ctx context.Context) (io.ReadCloser, error) {
if resp, err := j.scrapeClient.Do(req); err != nil {
return nil, err
} else if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp.Body, nil
return &cancelingReadCloser{resp.Body, cancel}, nil
} else {
resp.Body.Close()
return nil, &url.Error{
Expand All @@ -268,21 +319,34 @@ func (j *job) Scrape(ctx context.Context) (io.ReadCloser, error) {
}
}

func getClient(certFile, keyFile string) (*http.Client, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.ResponseHeaderTimeout = time.Minute
tlsConfig := &tls.Config{InsecureSkipVerify: true}
transport.TLSClientConfig = tlsConfig

if certFile != "" && keyFile != "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
func newHttpClient(tlsConfig *tls.Config) *http.Client {
if tlsConfig == nil {
tlsConfig = new(tls.Config)
}
tlsConfig.MinVersion = tls.VersionTLS12
tlsConfig.CipherSuites = constant.AllowedTLS12CipherSuiteIDs

return &http.Client{
Transport: transport,
Timeout: time.Minute,
}, nil
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true, // This is to be used on loopback connections.
MaxIdleConns: 1, // There won't be any concurrent connections.
IdleConnTimeout: 1 * time.Minute, // The metrics scraper interval is 30 secs by default.
},
CheckRedirect: disallowRedirects,
}
}

func disallowRedirects(req *http.Request, via []*http.Request) error {
return fmt.Errorf("no redirects allowed: %s", req.URL)
}

type cancelingReadCloser struct {
io.ReadCloser
cancel context.CancelFunc
}

func (c *cancelingReadCloser) Close() error {
defer c.cancel()
return c.ReadCloser.Close()
}

0 comments on commit 5064d5e

Please sign in to comment.