Skip to content

Commit

Permalink
Refactor (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Feb 28, 2024
1 parent df4845c commit 8ac32ae
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 126 deletions.
4 changes: 2 additions & 2 deletions arguments.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type arguments struct {
Help bool `short:"h" long:"help" description:"Show this help"`
Version bool `long:"version" description:"Show version"`
URL string
AcceptedStatusCodes statusCodeCollection
AcceptedStatusCodes statusCodeSet
ExcludedPatterns []*regexp.Regexp
IncludePatterns []*regexp.Regexp
Header http.Header
Expand Down Expand Up @@ -78,7 +78,7 @@ func getArguments(ss []string) (*arguments, error) {
return nil, err
}

args.AcceptedStatusCodes, err = parseStatusCodeCollection(args.RawAcceptedStatusCodes)
args.AcceptedStatusCodes, err = parseStatusCodeSet(args.RawAcceptedStatusCodes)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions redirect_http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import (
type redirectHttpClient struct {
client httpClient
maxRedirections int
acceptedStatusCodes statusCodeCollection
acceptedStatusCodes statusCodeSet
}

func newRedirectHttpClient(c httpClient, maxRedirections int, acceptedStatusCodes statusCodeCollection) httpClient {
func newRedirectHttpClient(c httpClient, maxRedirections int, acceptedStatusCodes statusCodeSet) httpClient {
return &redirectHttpClient{c, maxRedirections, acceptedStatusCodes}
}

Expand Down Expand Up @@ -43,7 +43,7 @@ func (c *redirectHttpClient) Get(u *url.URL, header http.Header) (httpResponse,

code := r.StatusCode()

if c.acceptedStatusCodes.isInCollection(code) {
if c.acceptedStatusCodes.isInSet(code) {
return r, nil
} else if code >= 300 && code <= 399 {
i++
Expand Down
2 changes: 1 addition & 1 deletion redirect_http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

const testUrl = "http://foo.com"

var acceptedStatusCodes = statusCodeCollection{[]statusCodeRange{{200, 300}}}
var acceptedStatusCodes = statusCodeSet{{200, 300}: struct{}{}}

func TestNewRedirectHttpClient(t *testing.T) {
newRedirectHttpClient(newFakeHttpClient(nil), 42, acceptedStatusCodes)
Expand Down
41 changes: 0 additions & 41 deletions status_code_collection.go

This file was deleted.

47 changes: 0 additions & 47 deletions status_code_collection_test.go

This file was deleted.

38 changes: 21 additions & 17 deletions status_code_range.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
package main

import (
"errors"
"regexp"
"fmt"
"strconv"
"strings"
)

var fixedCodePattern = regexp.MustCompile(`^\s*(\d{3})\s*$`)
var rangeCodePattern = regexp.MustCompile(`^\s*(\d{3})\s*\.\.\s*(\d{3})\s*$`)

type statusCodeRange struct {
start int
end int
}

func parseStatusCodeRange(value string) (*statusCodeRange, error) {
fixedMatch := fixedCodePattern.FindAllStringSubmatch(value, -1)
if len(fixedMatch) > 0 {
code, _ := strconv.Atoi(fixedMatch[0][1])
return &statusCodeRange{code, code + 1}, nil
func parseStatusCodeRange(s string) (*statusCodeRange, error) {
if c, err := strconv.Atoi(s); err == nil {
return &statusCodeRange{c, c + 1}, nil
}

ss := strings.Split(s, "..")
if len(ss) != 2 {
return nil, fmt.Errorf("invalid status code range: %v", s)
}

rangeMatch := rangeCodePattern.FindAllStringSubmatch(value, -1)
if len(rangeMatch) > 0 {
start, _ := strconv.Atoi(rangeMatch[0][1])
end, _ := strconv.Atoi(rangeMatch[0][2])
return &statusCodeRange{start, end}, nil
cs := []int{0, 0}

for i, s := range ss {
c, err := strconv.Atoi(s)
if err != nil {
return nil, fmt.Errorf("invalid status code: %v", s)
}

cs[i] = c
}

return nil, errors.New("invalid HTTP response status code value")
return &statusCodeRange{cs[0], cs[1]}, nil
}

func (r *statusCodeRange) isInRange(code int) bool {
func (r statusCodeRange) isInRange(code int) bool {
return code >= r.start && code < r.end
}
29 changes: 14 additions & 15 deletions status_code_range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,35 @@ import (
)

func TestParsingFixedStatusCode(t *testing.T) {
code, err := parseStatusCodeRange("403")
r, err := parseStatusCodeRange("403")

assert.Nil(t, err)
assert.Equal(t, 403, code.start)
assert.Equal(t, 404, code.end)
assert.Equal(t, 403, r.start)
assert.Equal(t, 404, r.end)
}

func TestParsingStatusCodeRange(t *testing.T) {
code, err := parseStatusCodeRange("200..300")
r, err := parseStatusCodeRange("200..300")

assert.Nil(t, err)
assert.Equal(t, 200, code.start)
assert.Equal(t, 300, code.end)
assert.Equal(t, 200, r.start)
assert.Equal(t, 300, r.end)
}

func TestParsingInvalidStatusCode(t *testing.T) {
code, err := parseStatusCodeRange("foo")
_, err := parseStatusCodeRange("foo")

assert.NotNil(t, err)
assert.Nil(t, code)
}

func TestInRangeOfStatusCode(t *testing.T) {
code := statusCodeRange{200, 300}
r := statusCodeRange{200, 300}

assert.False(t, code.isInRange(199))
assert.True(t, code.isInRange(200))
assert.True(t, code.isInRange(201))
assert.False(t, r.isInRange(199))
assert.True(t, r.isInRange(200))
assert.True(t, r.isInRange(201))

assert.True(t, code.isInRange(298))
assert.True(t, code.isInRange(299))
assert.False(t, code.isInRange(300))
assert.True(t, r.isInRange(298))
assert.True(t, r.isInRange(299))
assert.False(t, r.isInRange(300))
}
30 changes: 30 additions & 0 deletions status_code_set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package main

import "strings"

type statusCodeSet map[statusCodeRange]struct{}

func parseStatusCodeSet(value string) (statusCodeSet, error) {
rs := statusCodeSet{}

for _, r := range strings.Split(value, ",") {
r, err := parseStatusCodeRange(r)
if err != nil {
return nil, err
}

rs[*r] = struct{}{}
}

return rs, nil
}

func (s statusCodeSet) isInSet(code int) bool {
for r := range s {
if r.isInRange(code) {
return true
}
}

return false
}
27 changes: 27 additions & 0 deletions status_code_set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package main

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestParsingValidStatusCodeSet(t *testing.T) {
for s, r := range map[string]statusCodeSet{
"200": {{200, 201}: {}},
"200..300": {{200, 300}: {}},
"200..207,403": {{200, 207}: {}, {403, 404}: {}},
} {
t.Run(s, func(t *testing.T) {
s, err := parseStatusCodeSet(s)

assert.Nil(t, err)
assert.Equal(t, s, r)
})
}
}

func TestParsingInvalidStatusCodeSet(t *testing.T) {
_, err := parseStatusCodeSet("200,foo")

assert.NotNil(t, err)
}

0 comments on commit 8ac32ae

Please sign in to comment.