diff --git a/arguments.go b/arguments.go index 6b32a07..dd962c0 100644 --- a/arguments.go +++ b/arguments.go @@ -41,7 +41,7 @@ type arguments struct { Help bool `short:"h" long:"help" description:"Show this help"` Version bool `long:"version" description:"Show version"` URL string - AcceptedStatusCodes statusCodeCollection + AcceptedStatusCodes statusCodeSet ExcludedPatterns []*regexp.Regexp IncludePatterns []*regexp.Regexp Header http.Header @@ -78,7 +78,7 @@ func getArguments(ss []string) (*arguments, error) { return nil, err } - args.AcceptedStatusCodes, err = parseStatusCodeCollection(args.RawAcceptedStatusCodes) + args.AcceptedStatusCodes, err = parseStatusCodeSet(args.RawAcceptedStatusCodes) if err != nil { return nil, err } diff --git a/redirect_http_client.go b/redirect_http_client.go index 511a197..15d6e78 100644 --- a/redirect_http_client.go +++ b/redirect_http_client.go @@ -11,10 +11,10 @@ import ( type redirectHttpClient struct { client httpClient maxRedirections int - acceptedStatusCodes statusCodeCollection + acceptedStatusCodes statusCodeSet } -func newRedirectHttpClient(c httpClient, maxRedirections int, acceptedStatusCodes statusCodeCollection) httpClient { +func newRedirectHttpClient(c httpClient, maxRedirections int, acceptedStatusCodes statusCodeSet) httpClient { return &redirectHttpClient{c, maxRedirections, acceptedStatusCodes} } @@ -43,7 +43,7 @@ func (c *redirectHttpClient) Get(u *url.URL, header http.Header) (httpResponse, code := r.StatusCode() - if c.acceptedStatusCodes.isInCollection(code) { + if c.acceptedStatusCodes.isInSet(code) { return r, nil } else if code >= 300 && code <= 399 { i++ diff --git a/redirect_http_client_test.go b/redirect_http_client_test.go index 0ed8c37..2676759 100644 --- a/redirect_http_client_test.go +++ b/redirect_http_client_test.go @@ -10,7 +10,7 @@ import ( const testUrl = "http://foo.com" -var acceptedStatusCodes = statusCodeCollection{[]statusCodeRange{{200, 300}}} +var acceptedStatusCodes = statusCodeSet{{200, 300}: struct{}{}} func TestNewRedirectHttpClient(t *testing.T) { newRedirectHttpClient(newFakeHttpClient(nil), 42, acceptedStatusCodes) diff --git a/status_code_collection.go b/status_code_collection.go deleted file mode 100644 index e60535a..0000000 --- a/status_code_collection.go +++ /dev/null @@ -1,41 +0,0 @@ -package main - -import "strings" - -type statusCodeCollection struct { - elements []statusCodeRange -} - -func parseStatusCodeCollection(value string) (statusCodeCollection, error) { - statusCodeRanges := []statusCodeRange{} - - for _, partial := range strings.Split(value, ",") { - if len(value) == 0 { - continue - } - - statusCodeRange, err := parseStatusCodeRange(partial) - - if err != nil { - return statusCodeCollection{}, err - } - - statusCodeRanges = append(statusCodeRanges, *statusCodeRange) - } - - if len(statusCodeRanges) == 0 { - statusCodeRanges = append(statusCodeRanges, statusCodeRange{200, 300}) - } - - return statusCodeCollection{statusCodeRanges}, nil -} - -func (c *statusCodeCollection) isInCollection(code int) bool { - for _, element := range c.elements { - if element.isInRange(code) { - return true - } - } - - return false -} diff --git a/status_code_collection_test.go b/status_code_collection_test.go deleted file mode 100644 index 2e89f98..0000000 --- a/status_code_collection_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestParsingEmptyStatusCodeCollection(t *testing.T) { - collection, err := parseStatusCodeCollection("") - - assert.Nil(t, err) - - assert.False(t, collection.isInCollection(199)) - assert.True(t, collection.isInCollection(200)) - assert.True(t, collection.isInCollection(201)) - - assert.True(t, collection.isInCollection(298)) - assert.True(t, collection.isInCollection(299)) - assert.False(t, collection.isInCollection(300)) -} - -func TestParsingValidStatusCodeCollection(t *testing.T) { - collection, err := parseStatusCodeCollection("200..207,403") - - assert.Nil(t, err) - - assert.False(t, collection.isInCollection(199)) - assert.True(t, collection.isInCollection(200)) - assert.True(t, collection.isInCollection(201)) - - assert.True(t, collection.isInCollection(205)) - assert.True(t, collection.isInCollection(206)) - assert.False(t, collection.isInCollection(207)) - - assert.False(t, collection.isInCollection(402)) - assert.True(t, collection.isInCollection(403)) - assert.False(t, collection.isInCollection(404)) -} - -func TestParsingInvalidStatusCodeCollection(t *testing.T) { - collection, err := parseStatusCodeCollection("200,foo") - - assert.NotNil(t, err) - - assert.NotNil(t, collection) - assert.NotNil(t, collection.isInCollection(200)) -} diff --git a/status_code_range.go b/status_code_range.go index a8ae3a7..3b0168d 100644 --- a/status_code_range.go +++ b/status_code_range.go @@ -1,36 +1,40 @@ package main import ( - "errors" - "regexp" + "fmt" "strconv" + "strings" ) -var fixedCodePattern = regexp.MustCompile(`^\s*(\d{3})\s*$`) -var rangeCodePattern = regexp.MustCompile(`^\s*(\d{3})\s*\.\.\s*(\d{3})\s*$`) - type statusCodeRange struct { start int end int } -func parseStatusCodeRange(value string) (*statusCodeRange, error) { - fixedMatch := fixedCodePattern.FindAllStringSubmatch(value, -1) - if len(fixedMatch) > 0 { - code, _ := strconv.Atoi(fixedMatch[0][1]) - return &statusCodeRange{code, code + 1}, nil +func parseStatusCodeRange(s string) (*statusCodeRange, error) { + if c, err := strconv.Atoi(s); err == nil { + return &statusCodeRange{c, c + 1}, nil + } + + ss := strings.Split(s, "..") + if len(ss) != 2 { + return nil, fmt.Errorf("invalid status code range: %v", s) } - rangeMatch := rangeCodePattern.FindAllStringSubmatch(value, -1) - if len(rangeMatch) > 0 { - start, _ := strconv.Atoi(rangeMatch[0][1]) - end, _ := strconv.Atoi(rangeMatch[0][2]) - return &statusCodeRange{start, end}, nil + cs := []int{0, 0} + + for i, s := range ss { + c, err := strconv.Atoi(s) + if err != nil { + return nil, fmt.Errorf("invalid status code: %v", s) + } + + cs[i] = c } - return nil, errors.New("invalid HTTP response status code value") + return &statusCodeRange{cs[0], cs[1]}, nil } -func (r *statusCodeRange) isInRange(code int) bool { +func (r statusCodeRange) isInRange(code int) bool { return code >= r.start && code < r.end } diff --git a/status_code_range_test.go b/status_code_range_test.go index a14a30a..47b1d51 100644 --- a/status_code_range_test.go +++ b/status_code_range_test.go @@ -6,36 +6,35 @@ import ( ) func TestParsingFixedStatusCode(t *testing.T) { - code, err := parseStatusCodeRange("403") + r, err := parseStatusCodeRange("403") assert.Nil(t, err) - assert.Equal(t, 403, code.start) - assert.Equal(t, 404, code.end) + assert.Equal(t, 403, r.start) + assert.Equal(t, 404, r.end) } func TestParsingStatusCodeRange(t *testing.T) { - code, err := parseStatusCodeRange("200..300") + r, err := parseStatusCodeRange("200..300") assert.Nil(t, err) - assert.Equal(t, 200, code.start) - assert.Equal(t, 300, code.end) + assert.Equal(t, 200, r.start) + assert.Equal(t, 300, r.end) } func TestParsingInvalidStatusCode(t *testing.T) { - code, err := parseStatusCodeRange("foo") + _, err := parseStatusCodeRange("foo") assert.NotNil(t, err) - assert.Nil(t, code) } func TestInRangeOfStatusCode(t *testing.T) { - code := statusCodeRange{200, 300} + r := statusCodeRange{200, 300} - assert.False(t, code.isInRange(199)) - assert.True(t, code.isInRange(200)) - assert.True(t, code.isInRange(201)) + assert.False(t, r.isInRange(199)) + assert.True(t, r.isInRange(200)) + assert.True(t, r.isInRange(201)) - assert.True(t, code.isInRange(298)) - assert.True(t, code.isInRange(299)) - assert.False(t, code.isInRange(300)) + assert.True(t, r.isInRange(298)) + assert.True(t, r.isInRange(299)) + assert.False(t, r.isInRange(300)) } diff --git a/status_code_set.go b/status_code_set.go new file mode 100644 index 0000000..790d434 --- /dev/null +++ b/status_code_set.go @@ -0,0 +1,30 @@ +package main + +import "strings" + +type statusCodeSet map[statusCodeRange]struct{} + +func parseStatusCodeSet(value string) (statusCodeSet, error) { + rs := statusCodeSet{} + + for _, r := range strings.Split(value, ",") { + r, err := parseStatusCodeRange(r) + if err != nil { + return nil, err + } + + rs[*r] = struct{}{} + } + + return rs, nil +} + +func (s statusCodeSet) isInSet(code int) bool { + for r := range s { + if r.isInRange(code) { + return true + } + } + + return false +} diff --git a/status_code_set_test.go b/status_code_set_test.go new file mode 100644 index 0000000..3bf5e1f --- /dev/null +++ b/status_code_set_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParsingValidStatusCodeSet(t *testing.T) { + for s, r := range map[string]statusCodeSet{ + "200": {{200, 201}: {}}, + "200..300": {{200, 300}: {}}, + "200..207,403": {{200, 207}: {}, {403, 404}: {}}, + } { + t.Run(s, func(t *testing.T) { + s, err := parseStatusCodeSet(s) + + assert.Nil(t, err) + assert.Equal(t, s, r) + }) + } +} + +func TestParsingInvalidStatusCodeSet(t *testing.T) { + _, err := parseStatusCodeSet("200,foo") + + assert.NotNil(t, err) +}