Skip to content

Commit

Permalink
Handle case of empty string CAs in TLS config (#2063)
Browse files Browse the repository at this point in the history
* Don't set callback in case of disabled TLS

* Treat empty string CAs as not configured

* Add validation of TLS config to NewTLSConfigProviderFromConfig()

* Revert "Treat empty string CAs as not configured"

This reverts commit 52e7bc5.

* Revert "Don't set callback in case of disabled TLS"

This reverts commit b71651e.

* Add tests for TLS config validation
  • Loading branch information
sergeybykov authored Oct 18, 2021
1 parent 87a89e6 commit 92081a5
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 0 deletions.
91 changes: 91 additions & 0 deletions common/rpc/encryption/tlsFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ package encryption
import (
"crypto/tls"
"crypto/x509"
"fmt"
"strings"
"time"

"github.com/uber-go/tally"
Expand Down Expand Up @@ -89,8 +91,97 @@ func NewTLSConfigProviderFromConfig(
certProviderFactory CertProviderFactory,
) (TLSConfigProvider, error) {

if err := validateRootTLS(&encryptionSettings); err != nil {
return nil, err
}
if certProviderFactory == nil {
certProviderFactory = NewLocalStoreCertProvider
}
return NewLocalStoreTlsProvider(&encryptionSettings, scope, logger, certProviderFactory)
}

func validateRootTLS(cfg *config.RootTLS) error {
if err := validateGroupTLS(&cfg.Internode); err != nil {
return err
}
if err := validateGroupTLS(&cfg.Frontend); err != nil {
return err
}
if err := validateWorkerTLS(&cfg.SystemWorker); err != nil {
return err
}
return nil
}

func validateGroupTLS(cfg *config.GroupTLS) error {
if err := validateServerTLS(&cfg.Server); err != nil {
return err
}
if err := validateClientTLS(&cfg.Client); err != nil {
return err
}
for host, hostConfig := range cfg.PerHostOverrides {

if strings.TrimSpace(host) == "" {
return fmt.Errorf("host name cannot be empty string")
}
if err := validateServerTLS(&hostConfig); err != nil {
return err
}
}
return nil
}

func validateWorkerTLS(cfg *config.WorkerTLS) error {
if cfg.CertFile != "" && cfg.CertData != "" {
return fmt.Errorf("cannot specify CertFile and CertData at the same time")
}
if cfg.KeyFile != "" && cfg.KeyData != "" {
return fmt.Errorf("cannot specify KeyFile and KeyData at the same time")
}
if err := validateClientTLS(&cfg.Client); err != nil {
return err
}
return nil
}

func validateServerTLS(cfg *config.ServerTLS) error {
if cfg.CertFile != "" && cfg.CertData != "" {
return fmt.Errorf("cannot specify CertFile and CertData at the same time")
}
if cfg.KeyFile != "" && cfg.KeyData != "" {
return fmt.Errorf("cannot specify KeyFile and KeyData at the same time")
}
if err := validateCAs(cfg.ClientCAData); err != nil {
return fmt.Errorf("invalid ServerTLS.ClientCAData: %w", err)
}
if err := validateCAs(cfg.ClientCAFiles); err != nil {
return fmt.Errorf("invalid ServerTLS.ClientCAFiles: %w", err)
}
if len(cfg.ClientCAFiles) > 0 && len(cfg.ClientCAData) > 0 {
return fmt.Errorf("cannot specify ClientCAFiles and ClientCAData at the same time")
}
return nil
}

func validateClientTLS(cfg *config.ClientTLS) error {
if err := validateCAs(cfg.RootCAData); err != nil {
return fmt.Errorf("invalid ClientTLS.RootCAData: %w", err)
}
if err := validateCAs(cfg.RootCAFiles); err != nil {
return fmt.Errorf("invalid ClientTLS.RootCAFiles: %w", err)
}
if len(cfg.RootCAData) > 0 && len(cfg.RootCAFiles) > 0 {
return fmt.Errorf("cannot specify RootCAFiles and RootCAData at the same time")
}
return nil
}

func validateCAs(cas []string) error {
for _, ca := range cas {
if strings.TrimSpace(ca) == "" {
return fmt.Errorf("CA cannot be empty string")
}
}
return nil
}
167 changes: 167 additions & 0 deletions common/rpc/encryption/tls_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,170 @@ func (s *tlsConfigTest) TestIsSystemWorker() {
cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{Client: config.ClientTLS{ForceTLS: false}}}
s.False(isSystemWorker(cfg))
}

func (s *tlsConfigTest) TestCertFileAndData() {
s.testGroupTLS(s.testCertFileAndData)
}

func (s *tlsConfigTest) TestKeyFileAndData() {
s.testGroupTLS(s.testKeyFileAndData)
}

func (s *tlsConfigTest) TestClientCAData() {
s.testGroupTLS(s.testClientCAData)
}

func (s *tlsConfigTest) TestClientCAFiles() {
s.testGroupTLS(s.testClientCAFiles)
}

func (s *tlsConfigTest) TestRootCAData() {
s.testGroupTLS(s.testRootCAData)
}

func (s *tlsConfigTest) TestRootCAFiles() {
s.testGroupTLS(s.testRootCAFiles)
}

func (s *tlsConfigTest) testGroupTLS(f func(*config.RootTLS, *config.GroupTLS)) {

cfg := &config.RootTLS{Internode: config.GroupTLS{}}
f(cfg, &cfg.Internode)
cfg = &config.RootTLS{Frontend: config.GroupTLS{}}
f(cfg, &cfg.Frontend)
}

func (s *tlsConfigTest) testClientTLS(f func(*config.RootTLS, *config.ClientTLS)) {

cfg := &config.RootTLS{Internode: config.GroupTLS{}}
f(cfg, &cfg.Internode.Client)
cfg = &config.RootTLS{Frontend: config.GroupTLS{}}
f(cfg, &cfg.Frontend.Client)
}

func (s *tlsConfigTest) testServerTLS(f func(*config.RootTLS, *config.ServerTLS)) {

cfg := &config.RootTLS{Internode: config.GroupTLS{}}
f(cfg, &cfg.Internode.Server)
cfg = &config.RootTLS{Frontend: config.GroupTLS{}}
f(cfg, &cfg.Frontend.Server)
}

func (s *tlsConfigTest) testCertFileAndData(cfg *config.RootTLS, group *config.GroupTLS) {

group.Server = config.ServerTLS{}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{CertFile: "foo"}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{CertData: "bar"}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{CertFile: "foo", CertData: "bar"}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) testKeyFileAndData(cfg *config.RootTLS, group *config.GroupTLS) {

group.Server = config.ServerTLS{}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{KeyFile: "foo"}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{KeyData: "bar"}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{KeyFile: "foo", KeyData: "bar"}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) testClientCAData(cfg *config.RootTLS, group *config.GroupTLS) {

group.Server = config.ServerTLS{}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAData: []string{}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAData: []string{"foo"}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAData: []string{"foo", "bar"}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAData: []string{"foo", " "}}
s.Error(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAData: []string{""}}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) testClientCAFiles(cfg *config.RootTLS, group *config.GroupTLS) {

group.Server = config.ServerTLS{}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAFiles: []string{}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAFiles: []string{"foo"}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAFiles: []string{"foo", "bar"}}
s.Nil(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAFiles: []string{"foo", " "}}
s.Error(validateRootTLS(cfg))
group.Server = config.ServerTLS{ClientCAFiles: []string{""}}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) testRootCAData(cfg *config.RootTLS, group *config.GroupTLS) {

group.Client = config.ClientTLS{}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAData: []string{}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAData: []string{"foo"}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAData: []string{"foo", "bar"}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAData: []string{"foo", " "}}
s.Error(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAData: []string{""}}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) testRootCAFiles(cfg *config.RootTLS, group *config.GroupTLS) {

group.Client = config.ClientTLS{}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAFiles: []string{}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAFiles: []string{"foo"}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAFiles: []string{"foo", "bar"}}
s.Nil(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAFiles: []string{"foo", " "}}
s.Error(validateRootTLS(cfg))
group.Client = config.ClientTLS{RootCAFiles: []string{""}}
s.Error(validateRootTLS(cfg))
}

func (s *tlsConfigTest) TestSystemWorkerTLSConfig() {
cfg := &config.RootTLS{}
cfg.SystemWorker = config.WorkerTLS{}
s.Nil(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{CertFile: "foo"}
s.Nil(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{CertData: "bar"}
s.Nil(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{CertFile: "foo", CertData: "bar"}
s.Error(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{KeyFile: "foo"}
s.Nil(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{KeyData: "bar"}
s.Nil(validateRootTLS(cfg))
cfg.SystemWorker = config.WorkerTLS{KeyFile: "foo", KeyData: "bar"}
s.Error(validateRootTLS(cfg))

cfg.SystemWorker = config.WorkerTLS{Client: config.ClientTLS{}}
client := &cfg.SystemWorker.Client
client.RootCAData = []string{}
s.Nil(validateRootTLS(cfg))
client.RootCAData = []string{"foo"}
s.Nil(validateRootTLS(cfg))
client.RootCAData = []string{"foo", "bar"}
s.Nil(validateRootTLS(cfg))
client.RootCAData = []string{"foo", " "}
s.Error(validateRootTLS(cfg))
client.RootCAData = []string{""}
s.Error(validateRootTLS(cfg))
}

0 comments on commit 92081a5

Please sign in to comment.