Skip to content

Commit

Permalink
cb/sm: add option to disable csrf; encode state parameter with redire…
Browse files Browse the repository at this point in the history
…ct URL in json
  • Loading branch information
Limess authored and dosyara committed Jan 19, 2021
1 parent c1e1afe commit 32d424a
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 16 deletions.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func newDefaultConfig() *Config {
EnableDefaultDeny: true,
EnableSessionCookies: true,
EnableTokenHeader: true,
EnableCSRFCheck: false,
HTTPOnlyCookie: true,
Headers: make(map[string]string),
LetsEncryptCacheDir: "./cache/",
Expand Down
15 changes: 13 additions & 2 deletions cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package main

import (
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -119,16 +120,26 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr
r.dropCookieWithChunks(req, w, r.config.CookieRefreshName, value, duration)
}

type StateParameter struct {
Token string `json:"token"`
Url string `json:"url"`
}

// writeStateParameterCookie sets a state parameter cookie into the response
func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string {
func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) (string, error) {
uuid, err := uuid.NewV4()
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
requestURI := base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI()))
r.dropCookie(w, req.Host, requestURICookie, requestURI, 0)
r.dropCookie(w, req.Host, requestStateCookie, uuid.String(), 0)
return uuid.String()

stateParam := StateParameter{Token: uuid.String(),
Url: req.URL.RequestURI()}
output, err := json.Marshal(stateParam)

return string(output), err
}

// clearAllCookies is just a helper function for the below
Expand Down
2 changes: 2 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ type Config struct {
LocalhostMetrics bool `json:"localhost-metrics" yaml:"localhost-metrics" usage:"enforces the metrics page can only been requested from 127.0.0.1"`
// EnableCompression enables gzip compression for response
EnableCompression bool `json:"enable-compression" yaml:"enable-compression" usage:"enable gzip compression for response"`
// EnableCSRFCheck enables CSRF protection between authorise/callback using cookies and state parameter
EnableCSRFCheck bool `json:"enable-csrf-check" yaml:"enable-csrf-check" usage:"enable crsf check between authorise and callback"`

// AccessTokenDuration is default duration applied to the access token cookie
AccessTokenDuration time.Duration `json:"access-token-duration" yaml:"access-token-duration" usage:"fallback cookie duration for the access token when using refresh tokens"`
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ require (
github.com/PuerkitoBio/purell v1.1.0
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f
github.com/client9/misspell v0.3.4 // indirect
github.com/codegangsta/negroni v1.0.0 // indirect
github.com/coreos/go-oidc v0.0.0-20171020180921-e860bd55bfa7
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f h1:SaJ6yqg936T
github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY=
github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0=
github.com/coreos/go-oidc v0.0.0-20171020180921-e860bd55bfa7 h1:UeXD8Kli+SWhDlj1ikNXs9NKHsm2SR9dVnGiKq86DJ4=
Expand Down
35 changes: 28 additions & 7 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,23 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request)
redirect = r.config.RedirectionURL
}

state, _ := req.Cookie(requestStateCookie)
if state != nil && req.URL.Query().Get("state") != state.Value {
r.log.Error("state parameter mismatch")
w.WriteHeader(http.StatusForbidden)
return ""
if r.config.EnableCSRFCheck {
state, _ := req.Cookie(requestStateCookie)

stateParameter := req.URL.Query().Get("state")
stateParam := StateParameter{}
if stateParameter != "" {
err := json.Unmarshal([]byte(stateParameter), &stateParam)
if err != nil {
r.log.Warn("failed to deserialise state parameter from json")
}
}

if state != nil && stateParam.Token != state.Value {
r.log.Error("state parameter mismatch")
w.WriteHeader(http.StatusForbidden)
return ""
}
}
return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback"))
}
Expand Down Expand Up @@ -209,8 +221,17 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque

// step: decode the request variable
redirectURI := "/"
if req.URL.Query().Get("state") != "" {
if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil {
stateParameter := req.URL.Query().Get("state")
if stateParameter != "" {
stateParam := StateParameter{}
err := json.Unmarshal([]byte(stateParameter), &stateParam)
if err != nil {
r.log.Warn("failed to deserialise state parameter from json")
}

if stateParam.Url != "" {
redirectURI = stateParam.Url
} else if encodedRequestURI, _ := req.Cookie(requestURICookie); encodedRequestURI != nil {
decoded, _ := base64.StdEncoding.DecodeString(encodedRequestURI.Value)
redirectURI = string(decoded)
}
Expand Down
24 changes: 20 additions & 4 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
package main

import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
Expand Down Expand Up @@ -73,6 +74,8 @@ type fakeRequest struct {

// advanced test cases
ExpectedCookiesValidator map[string]func(string) bool

ExpectedStateUrl string
}

type fakeProxy struct {
Expand Down Expand Up @@ -223,11 +226,24 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) {
if c.ExpectedCode != 0 {
assert.Equal(t, c.ExpectedCode, status, "case %d, expected status code: %d, got: %d", i, c.ExpectedCode, status)
}
if c.ExpectedLocation != "" {
if c.ExpectedLocation != "" || c.ExpectedStateUrl != "" {
l, _ := url.Parse(resp.Header().Get("Location"))
assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String())
if l.Query().Get("state") != "" {
state, err := uuid.FromString(l.Query().Get("state"))
if c.ExpectedLocation != "" {
assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "expected location to contain %s", l.String())
}
stateStr := l.Query().Get("state")
if stateStr != "" {
stateParam := StateParameter{}
err := json.Unmarshal([]byte(stateStr), &stateParam)
if err != nil {
assert.Fail(t, "failed to deserialise state parameter from json, got: %s with error %s", stateStr, err)
}

if c.ExpectedStateUrl != "" {
assert.Equal(t, c.ExpectedStateUrl, stateParam.Url, "expected state url to contain %s", stateParam.Url)
}

state, err := uuid.FromString(stateParam.Token)
if err != nil {
assert.Fail(t, "expected state parameter with valid UUID, got: %s with error %s", state.String(), err)
}
Expand Down
10 changes: 8 additions & 2 deletions misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"time"
Expand Down Expand Up @@ -97,8 +98,13 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re
}

// step: add a state referrer to the authorization page
uuid := r.writeStateParameterCookie(req, w)
authQuery := fmt.Sprintf("?state=%s", uuid)
state, err := r.writeStateParameterCookie(req, w)
if err != nil {
r.log.Error("failed to create state parameter")
w.WriteHeader(http.StatusInternalServerError)
return r.revokeProxy(w, req)
}
authQuery := fmt.Sprintf("?state=%s", url.QueryEscape(state))

// step: if verification is switched off, we can't authorization
if r.config.SkipTokenVerification {
Expand Down
3 changes: 2 additions & 1 deletion misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ func TestRedirectToAuthorizationUnauthorized(t *testing.T) {
func TestRedirectToAuthorization(t *testing.T) {
requests := []fakeRequest{
{
URI: "/admin",
URI: "/admin?blah=1",
Redirects: true,
ExpectedLocation: "/oauth/authorize?state",
ExpectedStateUrl: "/admin?blah=1",
ExpectedCode: http.StatusSeeOther,
},
}
Expand Down

0 comments on commit 32d424a

Please sign in to comment.