diff --git a/pkg/stripe/url.go b/pkg/stripe/url.go index 2cc63344..8d7a9165 100644 --- a/pkg/stripe/url.go +++ b/pkg/stripe/url.go @@ -5,81 +5,60 @@ import ( "regexp" ) -// DefaultAPIBaseURL is the default base URL for API requests -const DefaultAPIBaseURL = "https://api.stripe.com" - -// qaAPIBaseURL is the base URL for API requests in QA -const qaAPIBaseURL = "https://qa-api.stripe.com" - -// devAPIBaseURLRegexp is the base URL for API requests in dev -const devAPIBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+api-mydev.dev.stripe.me` - -// DefaultFilesAPIBaseURL is the default base URL for Files API requsts -const DefaultFilesAPIBaseURL = "https://files.stripe.com" - -// DefaultDashboardBaseURL is the default base URL for dashboard requests -const DefaultDashboardBaseURL = "https://dashboard.stripe.com" - -// qaDashboardBaseURL is the base URL for dashboard requests in QA -const qaDashboardBaseURL = "https://qa-dashboard.stripe.com" - -// devDashboardBaseURLRegexp is the base URL for dashboard requests in dev -const devDashboardBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+manage-mydev\.dev\.stripe\.me` +const ( + // DefaultAPIBaseURL is the default base URL for API requests + DefaultAPIBaseURL = "https://api.stripe.com" + qaAPIBaseURL = "https://qa-api.stripe.com" + devAPIBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+api-mydev.dev.stripe.me` + + // DefaultFilesAPIBaseURL is the default base URL for Files API requsts + DefaultFilesAPIBaseURL = "https://files.stripe.com" + + // DefaultDashboardBaseURL is the default base URL for dashboard requests + DefaultDashboardBaseURL = "https://dashboard.stripe.com" + qaDashboardBaseURL = "https://qa-dashboard.stripe.com" + devDashboardBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+manage-mydev\.dev\.stripe\.me` + + // localhostURLRegexp is used in tests + localhostURLRegexp = `http:\/\/127\.0\.0\.1(:[0-9]+)?` +) -// localhostURLRegexp is used in tests -const localhostURLRegexp = `http:\/\/127\.0\.0\.1(:[0-9]+)?` +var ( + errInvalidAPIBaseURL = errors.New("invalid API base URL") + errInvalidDashboardBaseURL = errors.New("invalid dashboard base URL") +) -var errInvalidAPIBaseURL = errors.New("invalid API base URL") -var errInvalidDashboardBaseURL = errors.New("invalid dashboard base URL") +func isValid(url string, exactStrings []string, regexpStrings []string) bool { + for _, s := range exactStrings { + if url == s { + return true + } + } + for _, r := range regexpStrings { + matched, err := regexp.Match(r, []byte(url)) + if err == nil && matched { + return true + } + } + return false +} // ValidateAPIBaseURL returns an error if apiBaseURL isn't allowed func ValidateAPIBaseURL(apiBaseURL string) error { - if apiBaseURL == DefaultAPIBaseURL { + exactStrings := []string{DefaultAPIBaseURL, qaAPIBaseURL} + regexpStrings := []string{devAPIBaseURLRegexp, localhostURLRegexp} + if isValid(apiBaseURL, exactStrings, regexpStrings) { return nil } - if apiBaseURL == qaAPIBaseURL { - return nil - } - - matchedDev, err := regexp.Match(devAPIBaseURLRegexp, []byte(apiBaseURL)) - if err != nil { - return errInvalidAPIBaseURL - } - - matchedLocalhost, err := regexp.Match(localhostURLRegexp, []byte(apiBaseURL)) - if err != nil { - return errInvalidAPIBaseURL - } - - if !matchedDev && !matchedLocalhost { - return errInvalidAPIBaseURL - } - - return nil + return errInvalidAPIBaseURL } -// ValidateDashboardBaseURL returns true if dashboardBaseURL is allowed +// ValidateDashboardBaseURL returns an error if dashboardBaseURL isn't allowed func ValidateDashboardBaseURL(dashboardBaseURL string) error { - if dashboardBaseURL == DefaultDashboardBaseURL { - return nil - } - if dashboardBaseURL == qaDashboardBaseURL { + exactStrings := []string{DefaultDashboardBaseURL, qaDashboardBaseURL} + regexpStrings := []string{devDashboardBaseURLRegexp, localhostURLRegexp} + if isValid(dashboardBaseURL, exactStrings, regexpStrings) { return nil } - - matchedDev, err := regexp.Match(devDashboardBaseURLRegexp, []byte(dashboardBaseURL)) - if err != nil { - return errInvalidDashboardBaseURL - } - - matchedLocalhost, err := regexp.Match(localhostURLRegexp, []byte(dashboardBaseURL)) - if err != nil { - return errInvalidAPIBaseURL - } - - if !matchedDev && !matchedLocalhost { - return errInvalidAPIBaseURL - } - - return nil + return errInvalidDashboardBaseURL } diff --git a/pkg/stripe/url_test.go b/pkg/stripe/url_test.go index 1034c974..91df1dcd 100644 --- a/pkg/stripe/url_test.go +++ b/pkg/stripe/url_test.go @@ -14,10 +14,10 @@ func TestValidateAPIBaseURLWorks(t *testing.T) { assert.Nil(t, ValidateAPIBaseURL("http://127.0.0.1")) assert.Nil(t, ValidateAPIBaseURL("http://127.0.0.1:1337")) - assert.Error(t, ValidateAPIBaseURL("https://example.com")) - assert.Error(t, ValidateAPIBaseURL("https://unknowndomain")) - assert.Error(t, ValidateAPIBaseURL("localhost")) - assert.Error(t, ValidateAPIBaseURL("anything_else")) + assert.ErrorIs(t, ValidateAPIBaseURL("https://example.com"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("https://unknowndomain"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("localhost"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("anything_else"), errInvalidAPIBaseURL) } func TestValidateDashboardBaseURLWorks(t *testing.T) { @@ -28,8 +28,8 @@ func TestValidateDashboardBaseURLWorks(t *testing.T) { assert.Nil(t, ValidateDashboardBaseURL("http://127.0.0.1")) assert.Nil(t, ValidateDashboardBaseURL("http://127.0.0.1:1337")) - assert.Error(t, ValidateDashboardBaseURL("https://example.com")) - assert.Error(t, ValidateDashboardBaseURL("https://unknowndomain")) - assert.Error(t, ValidateDashboardBaseURL("localhost")) - assert.Error(t, ValidateDashboardBaseURL("anything_else")) + assert.ErrorIs(t, ValidateDashboardBaseURL("https://example.com"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("https://unknowndomain"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("localhost"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("anything_else"), errInvalidDashboardBaseURL) }