diff --git a/config.go b/config.go index 080af6ad..4280b86f 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ func newDefaultConfig() *Config { EnableDefaultDeny: true, EnableSessionCookies: true, EnableTokenHeader: true, + EnableCSRFCheck: false, HTTPOnlyCookie: true, Headers: make(map[string]string), LetsEncryptCacheDir: "./cache/", diff --git a/cookies.go b/cookies.go index bcd7dd24..9e2ef721 100644 --- a/cookies.go +++ b/cookies.go @@ -17,6 +17,7 @@ package main import ( "encoding/base64" + "encoding/json" "net/http" "strconv" "strings" @@ -119,8 +120,13 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr r.dropCookieWithChunks(req, w, r.config.CookieRefreshName, value, duration) } +type StateParameter struct { + Token string `json:"token"` + Url string `json:"url"` +} + // writeStateParameterCookie sets a state parameter cookie into the response -func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string { +func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) (string, error) { uuid, err := uuid.NewV4() if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -128,7 +134,12 @@ func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.Respons requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI())) r.dropCookie(w, req.Host, requestURICookie, requestURI, 0) r.dropCookie(w, req.Host, requestStateCookie, uuid.String(), 0) - return uuid.String() + + stateParam := StateParameter{Token: uuid.String(), + Url: req.URL.RequestURI()} + output, err := json.Marshal(stateParam) + + return string(output), err } // clearAllCookies is just a helper function for the below diff --git a/doc.go b/doc.go index c42c53f0..1aad5f2a 100644 --- a/doc.go +++ b/doc.go @@ -252,6 +252,8 @@ type Config struct { LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"` // EnableCompression enables gzip compression for response EnableCompression bool `json:"enable-compression" yaml:"enable-compression" usage:"enable gzip compression for response"` + // EnableCSRFCheck enables CSRF protection between authorise/callback using cookies and state parameter + EnableCSRFCheck bool `json:"enable-csrf-check" yaml:"enable-csrf-check" usage:"enable crsf check between authorise and callback"` // AccessTokenDuration is default duration applied to the access token cookie AccessTokenDuration time.Duration `json:"access-token-duration" yaml:"access-token-duration" usage:"fallback cookie duration for the access token when using refresh tokens"` diff --git a/handlers.go b/handlers.go index 87b55b73..9340bc52 100644 --- a/handlers.go +++ b/handlers.go @@ -58,11 +58,23 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) redirect = r.config.RedirectionURL } - state, _ := req.Cookie(requestStateCookie) - if state != nil && req.URL.Query().Get("state") != state.Value { - r.log.Error("state parameter mismatch") - w.WriteHeader(http.StatusForbidden) - return "" + if r.config.EnableCSRFCheck { + state, _ := req.Cookie(requestStateCookie) + + stateParameter := req.URL.Query().Get("state") + stateParam := StateParameter{} + if stateParameter != "" { + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + } + + if state != nil && stateParam.Token != state.Value { + r.log.Error("state parameter mismatch") + w.WriteHeader(http.StatusForbidden) + return "" + } } return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback")) } @@ -209,8 +221,17 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque // step: decode the request variable redirectURI := "/" - if req.URL.Query().Get("state") != "" { - if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { + stateParameter := req.URL.Query().Get("state") + if stateParameter != "" { + stateParam := StateParameter{} + err := json.Unmarshal([]byte(stateParameter), &stateParam) + if err != nil { + r.log.Warn("failed to deserialise state parameter from json") + } + + if stateParam.Url != "" { + redirectURI = stateParam.Url + } else if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil { decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value) redirectURI = string(decoded) } diff --git a/misc.go b/misc.go index a0311356..2cc6619f 100644 --- a/misc.go +++ b/misc.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "net/http" + "net/url" "path" "strings" "time" @@ -97,8 +98,13 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re } // step: add a state referrer to the authorization page - uuid := r.writeStateParameterCookie(req, w) - authQuery := fmt.Sprintf("?state=%s", uuid) + state, err := r.writeStateParameterCookie(req, w) + if err != nil { + r.log.Error("failed to create state parameter") + w.WriteHeader(http.StatusInternalServerError) + return r.revokeProxy(w, req) + } + authQuery := fmt.Sprintf("?state=%s", url.QueryEscape(state)) // step: if verification is switched off, we can't authorization if r.config.SkipTokenVerification {