diff --git a/pkg/camo/proxy.go b/pkg/camo/proxy.go index 783d5ed..246aa90 100644 --- a/pkg/camo/proxy.go +++ b/pkg/camo/proxy.go @@ -281,7 +281,7 @@ func (p *Proxy) checkURL(reqURL *url.URL) error { // evaluate filters. first false value "fails" for i := 0; i < p.filtersLen; i++ { - if !p.filters[0](reqURL) { + if !p.filters[i](reqURL) { return errors.New("Rejected due to filter-ruleset") } } diff --git a/pkg/camo/proxy_test.go b/pkg/camo/proxy_test.go index 8834274..e72efcc 100644 --- a/pkg/camo/proxy_test.go +++ b/pkg/camo/proxy_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "testing" "time" @@ -41,10 +42,23 @@ func makeReq(testURL string) (*http.Request, error) { return req, nil } -func processRequest(req *http.Request, status int, camoConfig Config) (*httptest.ResponseRecorder, error) { - camoServer, err := New(camoConfig) - if err != nil { - return nil, fmt.Errorf("Error building Camo: %s", err.Error()) +func processRequest(req *http.Request, status int, camoConfig Config, filters []FilterFunc) (*httptest.ResponseRecorder, error) { + + var ( + camoServer *Proxy + err error + ) + + if filters == nil || len(filters) == 0 { + camoServer, err = New(camoConfig) + if err != nil { + return nil, fmt.Errorf("Error building Camo: %s", err.Error()) + } + } else { + camoServer, err = NewWithFilters(camoConfig, filters) + if err != nil { + return nil, fmt.Errorf("Error building Camo: %s", err.Error()) + } } router := &router.DumbRouter{ @@ -66,7 +80,7 @@ func makeTestReq(testURL string, status int, config Config) (*httptest.ResponseR if err != nil { return nil, err } - record, err := processRequest(req, status, config) + record, err := processRequest(req, status, config, nil) if err != nil { return record, err } @@ -78,7 +92,7 @@ func TestNotFound(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com/favicon.ico", nil) assert.Nil(t, err) - record, err := processRequest(req, 404, camoConfig) + record, err := processRequest(req, 404, camoConfig, nil) if assert.Nil(t, err) { assert.Equal(t, 404, record.Code, "Expected 404 but got '%d' instead", record.Code) assert.Equal(t, "404 Not Found\n", record.Body.String(), "Expected 404 response body but got '%s' instead", record.Body.String()) @@ -173,12 +187,12 @@ func TestXForwardedFor(t *testing.T) { req.Header.Set("X-Forwarded-For", "2.2.2.2, 1.1.1.1") - record, err := processRequest(req, 200, camoConfigWithoutFwd4) + record, err := processRequest(req, 200, camoConfigWithoutFwd4, nil) assert.Nil(t, err) assert.EqualValues(t, record.Body.String(), "2.2.2.2, 1.1.1.1") camoConfigWithoutFwd4.EnableXFwdFor = false - record, err = processRequest(req, 200, camoConfigWithoutFwd4) + record, err = processRequest(req, 200, camoConfigWithoutFwd4, nil) assert.Nil(t, err) assert.Empty(t, record.Body.String()) } @@ -203,7 +217,7 @@ func TestVideoContentTypeAllowed(t *testing.T) { req, err := makeReq(testURL) assert.Nil(t, err) req.Header.Add("Range", "bytes=0-10") - record, err := processRequest(req, 206, camoConfigWithVideo) + record, err := processRequest(req, 206, camoConfigWithVideo, nil) resp := record.Result() assert.Equal(t, resp.Header.Get("Content-Range"), "bytes 0-10/179698") assert.Nil(t, err) @@ -374,8 +388,96 @@ func TestSupplyAcceptIfNoneGiven(t *testing.T) { req, err := makeReq(testURL) req.Header.Del("Accept") assert.Nil(t, err) - _, err = processRequest(req, 200, camoConfig) + _, err = processRequest(req, 200, camoConfig, nil) + assert.Nil(t, err) +} + +func TestFilterListAcceptSimple(t *testing.T) { + t.Parallel() + + called := false + filters := []FilterFunc{ + func(*url.URL) bool { + called = true + return true + }, + } + testURL := "http://www.google.com/images/srpr/logo11w.png" + req, err := makeReq(testURL) + _, err = processRequest(req, 200, camoConfig, filters) + assert.Nil(t, err) + assert.True(t, called, "filter func wasn't called") +} + +func TestFilterListMatrixMultiples(t *testing.T) { + t.Parallel() + + testURL := "http://www.google.com/images/srpr/logo11w.png" + req, err := makeReq(testURL) assert.Nil(t, err) + + var mixtests = []struct { + filterRuleAnswers []bool + expectedCallMatrix []bool + respcode int + }{ + // all rules return true, so all rules should have been called + // so pass: http200 + { + []bool{true, true, true}, + []bool{true, true, true}, + 200, + }, + // 3rd rule should not be called, because 2nd returned false + // so no pass: http404 + { + []bool{true, false, true}, + []bool{true, true, false}, + 404, + }, + // 2nd, 3rd rules should not be called, because 1st returned false + // so no pass: http404 + { + []bool{false, false, true}, + []bool{true, false, false}, + 404, + }, + // last rule returns false, but all rules should be called. + // so no pass: http404 + { + []bool{true, true, false}, + []bool{true, true, true}, + 404, + }, + } + + for _, tt := range mixtests { + callMatrix := []bool{false, false, false} + filters := make([]FilterFunc, 0) + for i := 0; i < 3; i++ { + filters = append( + filters, func(x int) func(*url.URL) bool { + return func(*url.URL) bool { + callMatrix[x] = true + return tt.filterRuleAnswers[x] + } + }(i), + ) + } + + _, err = processRequest(req, tt.respcode, camoConfig, filters) + assert.Nil(t, err) + for i := range callMatrix { + assert.Equal(t, + callMatrix[i], + tt.expectedCallMatrix[i], + fmt.Sprintf( + "filter func called='%t' wanted '%t'", + callMatrix[i], tt.expectedCallMatrix[i], + ), + ) + } + } } func TestTimeout(t *testing.T) { @@ -405,7 +507,7 @@ func TestTimeout(t *testing.T) { errc := make(chan error, 1) go func() { code := 504 - _, err := processRequest(req, code, c) + _, err := processRequest(req, code, c, nil) errc <- err }()