From 32d424a5e9ed7c00ab28f722eb6fbaae28282aae Mon Sep 17 00:00:00 2001 From: Charlie Briggs Date: Tue, 19 Jan 2021 16:00:15 +0000 Subject: [PATCH] cb/sm: add option to disable csrf; encode state parameter with redirect URL in json --- config.go | 1 + cookies.go | 15 +++++++++++++-- doc.go | 2 ++ go.mod | 1 + go.sum | 2 ++ handlers.go | 35 ++++++++++++++++++++++++++++------- middleware_test.go | 24 ++++++++++++++++++++---- misc.go | 10 ++++++++-- misc_test.go | 3 ++- 9 files changed, 77 insertions(+), 16 deletions(-) diff --git a/config.go b/config.go index 080af6ad..4280b86f 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ func newDefaultConfig() *Config { EnableDefaultDeny: true, EnableSessionCookies: true, EnableTokenHeader: true, + EnableCSRFCheck: false, HTTPOnlyCookie: true, Headers: make(map[string]string), LetsEncryptCacheDir: "./cache/", diff --git a/cookies.go b/cookies.go index bcd7dd24..9e2ef721 100644 --- a/cookies.go +++ b/cookies.go @@ -17,6 +17,7 @@ package main import ( "encoding/base64" + "encoding/json" "net/http" "strconv" "strings" @@ -119,8 +120,13 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr r.dropCookieWithChunks(req, w, r.config.CookieRefreshName, value, duration) } +type StateParameter struct { + Token string `json:"token"` + Url string `json:"url"` +} + // writeStateParameterCookie sets a state parameter cookie into the response -func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string { +func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) (string, error) { uuid, err := uuid.NewV4() if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -128,7 +134,12 @@ func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.Respons requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI())) r.dropCookie(w, req.Host, requestURICookie, requestURI, 0) r.dropCookie(w, req.Host, requestStateCookie, uuid.String(), 0) - return uuid.String() + + stateParam := StateParameter{Token: uuid.String(), + Url: req.URL.RequestURI()} + output, err := json.Marshal(stateParam) + + return string(output), err } // clearAllCookies is just a helper function for the below diff --git a/doc.go b/doc.go index c42c53f0..1aad5f2a 100644 --- a/doc.go +++ b/doc.go @@ -252,6 +252,8 @@ type Config struct { LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"` // EnableCompression enables gzip compression for response EnableCompression bool `json:"enable-compression" yaml:"enable-compression" usage:"enable gzip compression for response"` + // EnableCSRFCheck enables CSRF protection between authorise/callback using cookies and state parameter + EnableCSRFCheck bool `json:"enable-csrf-check" yaml:"enable-csrf-check" usage:"enable crsf check between authorise and callback"` // AccessTokenDuration is default duration applied to the access token cookie AccessTokenDuration time.Duration `json:"access-token-duration" yaml:"access-token-duration" usage:"fallback cookie duration for the access token when using refresh tokens"` diff --git a/go.mod b/go.mod index 9e0b2a3b..d4fa9c21 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ require ( github.com/PuerkitoBio/purell v1.1.0 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f + github.com/client9/misspell v0.3.4 // indirect github.com/codegangsta/negroni v1.0.0 // indirect github.com/coreos/go-oidc v0.0.0-20171020180921-e860bd55bfa7 github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect diff --git a/go.sum b/go.sum index 437e413c..77dbd71d 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f h1:SaJ6yqg936T github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY= github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0= github.com/coreos/go-oidc v0.0.0-20171020180921-e860bd55bfa7 h1:UeXD8Kli+SWhDlj1ikNXs9NKHsm2SR9dVnGiKq86DJ4= diff --git a/handlers.go b/handlers.go index 87b55b73..9340bc52 100644 --- a/handlers.go +++ b/handlers.go @@ -58,11 +58,23 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) redirect = r.config.RedirectionURL } - state, _ := req.Cookie(requestStateCookie) - if state != nil && req.URL.Query().Get("state") != state.Value { - r.log.Error("state parameter mismatch") - w.WriteHeader(http.StatusForbidden) - return "" + if r.config.EnableCSRFCheck { + state, _ := req.Cookie(requestStateCookie) + + stateParameter := req.URL.Query().Get("state") + stateParam := StateParameter{} + if stateParameter != "" { + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + } + + if state != nil && stateParam.Token != state.Value { + r.log.Error("state parameter mismatch") + w.WriteHeader(http.StatusForbidden) + return "" + } } return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback")) } @@ -209,8 +221,17 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque // step: decode the request variable redirectURI := "/" - if req.URL.Query().Get("state") != "" { - if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { + stateParameter := req.URL.Query().Get("state") + if stateParameter != "" { + stateParam := StateParameter{} + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + + if stateParam.Url != "" { + redirectURI = stateParam.Url + } else if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value) redirectURI = string(decoded) } diff --git a/middleware_test.go b/middleware_test.go index 774bc629..0b4cbb0b 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -16,6 +16,7 @@ limitations under the License. package main import ( + "encoding/json" "fmt" "io/ioutil" "log" @@ -73,6 +74,8 @@ type fakeRequest struct { // advanced test cases ExpectedCookiesValidator map[string]func(string) bool + + ExpectedStateUrl string } type fakeProxy struct { @@ -223,11 +226,24 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) { if c.ExpectedCode != 0 { assert.Equal(t, c.ExpectedCode, status, "case %d, expected status code: %d, got: %d", i, c.ExpectedCode, status) } - if c.ExpectedLocation != "" { + if c.ExpectedLocation != "" || c.ExpectedStateUrl != "" { l, _ := url.Parse(resp.Header().Get("Location")) - assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String()) - if l.Query().Get("state") != "" { - state, err := uuid.FromString(l.Query().Get("state")) + if c.ExpectedLocation != "" { + assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String()) + } + stateStr := l.Query().Get("state") + if stateStr != "" { + stateParam := StateParameter{} + err := json.Unmarshal([]byte(stateStr), &stateParam) + if err != nil { + assert.Fail(t, "failed to deserialise state parameter from json, got: %s with error %s", stateStr, err) + } + + if c.ExpectedStateUrl != "" { + assert.Equal(t, c.ExpectedStateUrl, stateParam.Url, "expected state url to contain %s", stateParam.Url) + } + + state, err := uuid.FromString(stateParam.Token) if err != nil { assert.Fail(t, "expected state parameter with valid UUID, got: %s with error %s", state.String(), err) } diff --git a/misc.go b/misc.go index a0311356..2cc6619f 100644 --- a/misc.go +++ b/misc.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "net/http" + "net/url" "path" "strings" "time" @@ -97,8 +98,13 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re } // step: add a state referrer to the authorization page - uuid := r.writeStateParameterCookie(req, w) - authQuery := fmt.Sprintf("?state=%s", uuid) + state, err := r.writeStateParameterCookie(req, w) + if err != nil { + r.log.Error("failed to create state parameter") + w.WriteHeader(http.StatusInternalServerError) + return r.revokeProxy(w, req) + } + authQuery := fmt.Sprintf("?state=%s", url.QueryEscape(state)) // step: if verification is switched off, we can't authorization if r.config.SkipTokenVerification { diff --git a/misc_test.go b/misc_test.go index c416bd27..15daac9c 100644 --- a/misc_test.go +++ b/misc_test.go @@ -33,9 +33,10 @@ func TestRedirectToAuthorizationUnauthorized(t *testing.T) { func TestRedirectToAuthorization(t *testing.T) { requests := []fakeRequest{ { - URI: "/admin", + URI: "/admin?blah=1", Redirects: true, ExpectedLocation: "/oauth/authorize?state", + ExpectedStateUrl: "/admin?blah=1", ExpectedCode: http.StatusSeeOther, }, }