Skip to content

Commit

Permalink
cb/sm: add option to disable csrf; encode state parameter with redire…
Browse files Browse the repository at this point in the history
…ct URL in json
  • Loading branch information
Limess committed Jan 19, 2021
1 parent c1e1afe commit 9543369
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 11 deletions.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func newDefaultConfig() *Config {
EnableDefaultDeny: true,
EnableSessionCookies: true,
EnableTokenHeader: true,
EnableCSRFCheck: false,
HTTPOnlyCookie: true,
Headers: make(map[string]string),
LetsEncryptCacheDir: "./cache/",
Expand Down
15 changes: 13 additions & 2 deletions cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package main

import (
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -119,16 +120,26 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr
r.dropCookieWithChunks(req, w, r.config.CookieRefreshName, value, duration)
}

type StateParameter struct {
Token string `json:"token"`
Url string `json:"url"`
}

// writeStateParameterCookie sets a state parameter cookie into the response
func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string {
func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) (string, error) {
uuid, err := uuid.NewV4()
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI()))
r.dropCookie(w, req.Host, requestURICookie, requestURI, 0)
r.dropCookie(w, req.Host, requestStateCookie, uuid.String(), 0)
return uuid.String()

stateParam := StateParameter{Token: uuid.String(),
Url: req.URL.RequestURI()}
output, err := json.Marshal(stateParam)

return string(output), err
}

// clearAllCookies is just a helper function for the below
Expand Down
2 changes: 2 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ type Config struct {
LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"`
// EnableCompression enables gzip compression for response
EnableCompression bool `json:"enable-compression" yaml:"enable-compression" usage:"enable gzip compression for response"`
// EnableCSRFCheck enables CSRF protection between authorise/callback using cookies and state parameter
EnableCSRFCheck bool `json:"enable-csrf-check" yaml:"enable-csrf-check" usage:"enable crsf check between authorise and callback"`

// AccessTokenDuration is default duration applied to the access token cookie
AccessTokenDuration time.Duration `json:"access-token-duration" yaml:"access-token-duration" usage:"fallback cookie duration for the access token when using refresh tokens"`
Expand Down
35 changes: 28 additions & 7 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,23 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request)
redirect = r.config.RedirectionURL
}

state, _ := req.Cookie(requestStateCookie)
if state != nil && req.URL.Query().Get("state") != state.Value {
r.log.Error("state parameter mismatch")
w.WriteHeader(http.StatusForbidden)
return ""
if r.config.EnableCSRFCheck {
state, _ := req.Cookie(requestStateCookie)

stateParameter := req.URL.Query().Get("state")
stateParam := StateParameter{}
if stateParameter != "" {
err := json.Unmarshal([]byte(stateParameter), &stateParam)
if err != nil {
r.log.Warn("failed to deserialise state parameter from json")
}
}

if state != nil && stateParam.Token != state.Value {
r.log.Error("state parameter mismatch")
w.WriteHeader(http.StatusForbidden)
return ""
}
}
return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback"))
}
Expand Down Expand Up @@ -209,8 +221,17 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque

// step: decode the request variable
redirectURI := "/"
if req.URL.Query().Get("state") != "" {
if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil {
stateParameter := req.URL.Query().Get("state")
if stateParameter != "" {
stateParam := StateParameter{}
err := json.Unmarshal([]byte(stateParameter), &stateParam)
if err != nil {
r.log.Warn("failed to deserialise state parameter from json")
}

if stateParam.Url != "" {
redirectURI = stateParam.Url
} else if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil {
decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value)
redirectURI = string(decoded)
}
Expand Down
10 changes: 8 additions & 2 deletions misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"time"
Expand Down Expand Up @@ -97,8 +98,13 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re
}

// step: add a state referrer to the authorization page
uuid := r.writeStateParameterCookie(req, w)
authQuery := fmt.Sprintf("?state=%s", uuid)
state, err := r.writeStateParameterCookie(req, w)
if err != nil {
r.log.Error("failed to create state parameter")
w.WriteHeader(http.StatusInternalServerError)
return r.revokeProxy(w, req)
}
authQuery := fmt.Sprintf("?state=%s", url.QueryEscape(state))

// step: if verification is switched off, we can't authorization
if r.config.SkipTokenVerification {
Expand Down

0 comments on commit 9543369

Please sign in to comment.