diff --git a/client.go b/client.go index ce7ef0c..c977dc1 100644 --- a/client.go +++ b/client.go @@ -101,6 +101,10 @@ func isSchemeValid(parsed *urllib.URL, config *Config, debugLogFunc func(string) func isHostValid(parsed *urllib.URL, config *Config, debugLogFunc func(string)) error { host := parsed.Hostname() + if host == "" { + debugLogFunc("empty host received") + return &InvalidHostError{host: ""} + } if config.AllowedHosts != nil && !isAllowedHost(host, config.AllowedHosts) { debugLogFunc(fmt.Sprintf("disallowed host: %s", host)) @@ -254,6 +258,14 @@ func (e *AllowedSchemeError) Error() string { return fmt.Sprintf("scheme: %v not found in allowlist", e.scheme) } +type InvalidHostError struct { + host string +} + +func (e *InvalidHostError) Error() string { + return fmt.Sprintf("host: %v is not valid", e.host) +} + type AllowedHostError struct { host string } diff --git a/client_test.go b/client_test.go index 581db6e..601110f 100644 --- a/client_test.go +++ b/client_test.go @@ -437,3 +437,24 @@ func TestInternalIPAreAlwaysBlocked(t *testing.T) { } } + +func TestInvalidHostValidation(t *testing.T) { + cfg := GetConfigBuilder().Build() + client := Client(cfg) + + urls := []string{"http://[]", "http://[]:123", "http://:123"} + + for _, url := range urls { + _, err := client.Get(url) + if err == nil { + t.Errorf("invalid host from url => %v was accepted. client didn't not return an error", err) + } + + err = unwrap(err) + _, ok := err.(*InvalidHostError) + if !ok { + t.Errorf("client returned incorrect error: %v", err) + } + } + +}