diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index 36af187b..274d670f 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -85,9 +85,11 @@ type Config struct { ProxyDialTimeout func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) // Custom handler to allow clients to modify reject responses + // Deprecated: RejectResponseHandler is deprecated.Please use RejectResponseHandlerWithCtx instead. RejectResponseHandler func(*http.Response) // Custom handler to allow clients to modify reject responses + // In case RejectResponseHandler is set, this cannot be used. RejectResponseHandlerWithCtx func(*SmokescreenContext, *http.Response) // Custom handler to allow clients to modify successful CONNECT responses @@ -421,6 +423,13 @@ func (config *Config) SetupTls(certFile, keyFile string, clientCAFiles []string) return nil } +func (config *Config) Validate() error { + if config.RejectResponseHandler != nil && config.RejectResponseHandlerWithCtx != nil { + return errors.New("RejectResponseHandler and RejectResponseHandlerWithCtx cannot be used together") + } + return nil +} + func (config *Config) populateClientCaMap(pemCerts []byte) (ok bool) { for len(pemCerts) > 0 { diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 789e91d9..7659dfcf 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -736,9 +736,14 @@ func findListener(ip string, defaultPort uint16) (net.Listener, error) { func StartWithConfig(config *Config, quit <-chan interface{}) { config.Log.Println("starting") + var err error + + if err = config.Validate(); err != nil { + config.Log.Fatal("invalid config", err) + } + proxy := BuildProxy(config) listener := config.Listener - var err error if listener == nil { listener, err = findListener(config.Ip, config.Port) @@ -782,6 +787,10 @@ func StartWithConfig(config *Config, quit <-chan interface{}) { server.IdleTimeout = config.IdleTimeout } + if config.RejectResponseHandler != nil && config.RejectResponseHandlerWithCtx != nil { + config.Log.Fatal("RejectResponseHandler and RejectResponseHandlerWithCtx cannot be set simultaneously") + } + config.MetricsClient.SetStarted() config.ShuttingDown.Store(false) runServer(config, &server, listener, quit) diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index fc4aae4f..89718145 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1535,6 +1535,38 @@ func TestMitm(t *testing.T) { }) } +func TestConfigValidate(t *testing.T) { + t.Run("Test invalid config", func(t *testing.T) { + conf := NewConfig() + conf.ConnectTimeout = 10 * time.Second + conf.ExitTimeout = 10 * time.Second + conf.AdditionalErrorMessageOnDeny = "Proxy denied" + conf.RejectResponseHandlerWithCtx = func(smokescreenContext *SmokescreenContext, response *http.Response) { + fmt.Println("RejectResponseHandlerWithCtx") + } + conf.RejectResponseHandler = func(response *http.Response) { + fmt.Println("RejectResponseHandler") + } + err := conf.Validate() + require.Error(t, err) + + }) + + t.Run("Test valid config", func(t *testing.T) { + conf := NewConfig() + conf.ConnectTimeout = 10 * time.Second + conf.ExitTimeout = 10 * time.Second + conf.AdditionalErrorMessageOnDeny = "Proxy denied" + + conf.RejectResponseHandler = func(response *http.Response) { + fmt.Println("RejectResponseHandler") + } + err := conf.Validate() + require.NoError(t, err) + + }) +} + func findCanonicalProxyDecision(logs []*logrus.Entry) *logrus.Entry { for _, entry := range logs { if entry.Message == CanonicalProxyDecision {