From a32a6ed13515f2bca148386dcab3fcb231f484f1 Mon Sep 17 00:00:00 2001 From: Nikolas Sepos Date: Sun, 15 Nov 2020 19:19:37 +0200 Subject: [PATCH 1/3] implement PKCE for AuthorizationCode grant --- const.go | 36 +++++++++++++ const_test.go | 21 ++++++++ errors/error.go | 3 ++ errors/response.go | 69 +++++++++++++----------- example/client/client.go | 13 ++++- manage.go | 21 ++++---- manage/manager.go | 29 +++++++++++ model.go | 4 ++ models/token.go | 48 ++++++++++++----- server/config.go | 32 +++++++----- server/server.go | 68 +++++++++++++++++++----- server/server_test.go | 110 +++++++++++++++++++++++++++++++++++++++ 12 files changed, 375 insertions(+), 79 deletions(-) create mode 100644 const_test.go diff --git a/const.go b/const.go index 3f81a7f..3309bc4 100644 --- a/const.go +++ b/const.go @@ -1,5 +1,10 @@ package oauth2 +import ( + "crypto/sha256" + "encoding/base64" +) + // ResponseType the type of authorization request type ResponseType string @@ -34,3 +39,34 @@ func (gt GrantType) String() string { } return "" } + +// CodeChallengeMethod PCKE method +type CodeChallengeMethod string + +const ( + // CodeChallengePlain PCKE Method + CodeChallengePlain CodeChallengeMethod = "plain" + // CodeChallengeS256 PCKE Method + CodeChallengeS256 CodeChallengeMethod = "S256" +) + +func (ccm CodeChallengeMethod) String() string { + if ccm == CodeChallengePlain || + ccm == CodeChallengeS256 { + return string(ccm) + } + return "" +} + +// Validate code challenge +func (ccm CodeChallengeMethod) Validate(cc, ver string) bool { + switch ccm { + case CodeChallengePlain: + return cc == ver + case CodeChallengeS256: + s256 := sha256.Sum256([]byte(ver)) + return base64.URLEncoding.EncodeToString(s256[:]) == cc + default: + return false + } +} diff --git a/const_test.go b/const_test.go new file mode 100644 index 0000000..7dbdcea --- /dev/null +++ b/const_test.go @@ -0,0 +1,21 @@ +package oauth2_test + +import ( + "testing" + + "github.com/go-oauth2/oauth2/v4" +) + +func TestValidatePlain(t *testing.T) { + cc := oauth2.CodeChallengePlain + if !cc.Validate("plaintest", "plaintest") { + t.Fatal("not valid") + } +} + +func TestValidateS256(t *testing.T) { + cc := oauth2.CodeChallengeS256 + if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") { + t.Fatal("not valid") + } +} diff --git a/errors/error.go b/errors/error.go index c106563..71ae046 100644 --- a/errors/error.go +++ b/errors/error.go @@ -13,4 +13,7 @@ var ( ErrInvalidRefreshToken = errors.New("invalid refresh token") ErrExpiredAccessToken = errors.New("expired access token") ErrExpiredRefreshToken = errors.New("expired refresh token") + ErrMissingCodeVerifier = errors.New("missing code verifier") + ErrMissingCodeChallenge = errors.New("missing code challenge") + ErrInvalidCodeChallenge = errors.New("invalid code challenge") ) diff --git a/errors/response.go b/errors/response.go index 5ba1199..c8d5902 100644 --- a/errors/response.go +++ b/errors/response.go @@ -34,42 +34,51 @@ func (r *Response) SetHeader(key, value string) { // https://tools.ietf.org/html/rfc6749#section-5.2 var ( - ErrInvalidRequest = errors.New("invalid_request") - ErrUnauthorizedClient = errors.New("unauthorized_client") - ErrAccessDenied = errors.New("access_denied") - ErrUnsupportedResponseType = errors.New("unsupported_response_type") - ErrInvalidScope = errors.New("invalid_scope") - ErrServerError = errors.New("server_error") - ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") - ErrInvalidClient = errors.New("invalid_client") - ErrInvalidGrant = errors.New("invalid_grant") - ErrUnsupportedGrantType = errors.New("unsupported_grant_type") + ErrInvalidRequest = errors.New("invalid_request") + ErrUnauthorizedClient = errors.New("unauthorized_client") + ErrAccessDenied = errors.New("access_denied") + ErrUnsupportedResponseType = errors.New("unsupported_response_type") + ErrInvalidScope = errors.New("invalid_scope") + ErrServerError = errors.New("server_error") + ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") + ErrInvalidClient = errors.New("invalid_client") + ErrInvalidGrant = errors.New("invalid_grant") + ErrUnsupportedGrantType = errors.New("unsupported_grant_type") + ErrCodeChallengeRquired = errors.New("invalid_request") + ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request") + ErrInvalidCodeChallengeLen = errors.New("invalid_request") ) // Descriptions error description var Descriptions = map[error]string{ - ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed", - ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", - ErrAccessDenied: "The resource owner or authorization server denied the request", - ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", - ErrInvalidScope: "The requested scope is invalid, unknown, or malformed", - ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request", - ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server", - ErrInvalidClient: "Client authentication failed", - ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client", - ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", + ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed", + ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", + ErrAccessDenied: "The resource owner or authorization server denied the request", + ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", + ErrInvalidScope: "The requested scope is invalid, unknown, or malformed", + ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request", + ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server", + ErrInvalidClient: "Client authentication failed", + ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client", + ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", + ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing", + ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported", + ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long", } // StatusCodes response error HTTP status code var StatusCodes = map[error]int{ - ErrInvalidRequest: 400, - ErrUnauthorizedClient: 401, - ErrAccessDenied: 403, - ErrUnsupportedResponseType: 401, - ErrInvalidScope: 400, - ErrServerError: 500, - ErrTemporarilyUnavailable: 503, - ErrInvalidClient: 401, - ErrInvalidGrant: 401, - ErrUnsupportedGrantType: 401, + ErrInvalidRequest: 400, + ErrUnauthorizedClient: 401, + ErrAccessDenied: 403, + ErrUnsupportedResponseType: 401, + ErrInvalidScope: 400, + ErrServerError: 500, + ErrTemporarilyUnavailable: 503, + ErrInvalidClient: 401, + ErrInvalidGrant: 401, + ErrUnsupportedGrantType: 401, + ErrCodeChallengeRquired: 400, + ErrUnsupportedCodeChallengeMethod: 400, + ErrInvalidCodeChallengeLen: 400, } diff --git a/example/client/client.go b/example/client/client.go index 1a21676..2999330 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -2,6 +2,8 @@ package main import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "io" @@ -33,7 +35,9 @@ var ( func main() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - u := config.AuthCodeURL("xyz") + u := config.AuthCodeURL("xyz", + oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")), + oauth2.SetAuthURLParam("code_challenge_method", "S256")) http.Redirect(w, r, u, http.StatusFound) }) @@ -49,7 +53,7 @@ func main() { http.Error(w, "Code not found", http.StatusBadRequest) return } - token, err := config.Exchange(context.Background(), code) + token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example")) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -130,3 +134,8 @@ func main() { log.Println("Client is running at 9094 port.Please open http://localhost:9094") log.Fatal(http.ListenAndServe(":9094", nil)) } + +func genCodeChallengeS256(s string) string { + s256 := sha256.Sum256([]byte(s)) + return base64.URLEncoding.EncodeToString(s256[:]) +} diff --git a/manage.go b/manage.go index 3b408e1..5c0bdf8 100644 --- a/manage.go +++ b/manage.go @@ -8,15 +8,18 @@ import ( // TokenGenerateRequest provide to generate the token request parameters type TokenGenerateRequest struct { - ClientID string - ClientSecret string - UserID string - RedirectURI string - Scope string - Code string - Refresh string - AccessTokenExp time.Duration - Request *http.Request + ClientID string + ClientSecret string + UserID string + RedirectURI string + Scope string + Code string + CodeChallenge string + CodeChallengeMethod CodeChallengeMethod + Refresh string + CodeVerifier string + AccessTokenExp time.Duration + Request *http.Request } // Manager authorization management interface diff --git a/manage/manager.go b/manage/manager.go index b8fb01e..1d48f77 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -176,6 +176,10 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, if exp := tgr.AccessTokenExp; exp > 0 { ti.SetAccessExpiresIn(exp) } + if tgr.CodeChallenge != "" { + ti.SetCodeChallenge(tgr.CodeChallenge) + ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod) + } tv, err := m.authorizeGenerate.Token(ctx, td) if err != nil { @@ -251,6 +255,28 @@ func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.To return ti, nil } +func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { + cc := ti.GetCodeChallenge() + // early return + if cc == "" && ver == "" { + return nil + } + if cc == "" { + return errors.ErrMissingCodeVerifier + } + if ver == "" { + return errors.New("missing code verifier") + } + ccm := ti.GetCodeChallengeMethod() + if ccm.String() == "" { + ccm = oauth2.CodeChallengePlain + } + if !ccm.Validate(cc, ver) { + return errors.ErrInvalidCodeChallenge + } + return nil +} + // GenerateAccessToken generate the access token func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { cli, err := m.GetClient(ctx, tgr.ClientID) @@ -275,6 +301,9 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, if err != nil { return nil, err } + if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil { + return nil, err + } tgr.UserID = ti.GetUserID() tgr.Scope = ti.GetScope() if exp := ti.GetAccessExpiresIn(); exp > 0 { diff --git a/model.go b/model.go index 26e5441..121a42d 100644 --- a/model.go +++ b/model.go @@ -37,6 +37,10 @@ type ( SetCodeCreateAt(time.Time) GetCodeExpiresIn() time.Duration SetCodeExpiresIn(time.Duration) + GetCodeChallenge() string + SetCodeChallenge(string) + GetCodeChallengeMethod() CodeChallengeMethod + SetCodeChallengeMethod(CodeChallengeMethod) GetAccess() string SetAccess(string) diff --git a/models/token.go b/models/token.go index f841823..882d823 100644 --- a/models/token.go +++ b/models/token.go @@ -13,19 +13,21 @@ func NewToken() *Token { // Token token model type Token struct { - ClientID string `bson:"ClientID"` - UserID string `bson:"UserID"` - RedirectURI string `bson:"RedirectURI"` - Scope string `bson:"Scope"` - Code string `bson:"Code"` - CodeCreateAt time.Time `bson:"CodeCreateAt"` - CodeExpiresIn time.Duration `bson:"CodeExpiresIn"` - Access string `bson:"Access"` - AccessCreateAt time.Time `bson:"AccessCreateAt"` - AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` - Refresh string `bson:"Refresh"` - RefreshCreateAt time.Time `bson:"RefreshCreateAt"` - RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` + ClientID string `bson:"ClientID"` + UserID string `bson:"UserID"` + RedirectURI string `bson:"RedirectURI"` + Scope string `bson:"Scope"` + Code string `bson:"Code"` + CodeChallenge string `bson:"CodeChallenge"` + CodeChallengeMethod string `bson:"CodeChallengeMethod"` + CodeCreateAt time.Time `bson:"CodeCreateAt"` + CodeExpiresIn time.Duration `bson:"CodeExpiresIn"` + Access string `bson:"Access"` + AccessCreateAt time.Time `bson:"AccessCreateAt"` + AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` + Refresh string `bson:"Refresh"` + RefreshCreateAt time.Time `bson:"RefreshCreateAt"` + RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` } // New create to token model instance @@ -103,6 +105,26 @@ func (t *Token) SetCodeExpiresIn(exp time.Duration) { t.CodeExpiresIn = exp } +// GetCodeChallenge challenge code +func (t *Token) GetCodeChallenge() string { + return t.CodeChallenge +} + +// SetCodeChallenge challenge code +func (t *Token) SetCodeChallenge(code string) { + t.CodeChallenge = code +} + +// GetCodeChallengeMethod challenge method +func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod { + return oauth2.CodeChallengeMethod(t.CodeChallengeMethod) +} + +// SetCodeChallengeMethod challenge method +func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) { + t.CodeChallengeMethod = string(method) +} + // GetAccess access Token func (t *Token) GetAccess() string { return t.Access diff --git a/server/config.go b/server/config.go index f5fc5d2..3bbb884 100644 --- a/server/config.go +++ b/server/config.go @@ -9,10 +9,12 @@ import ( // Config configuration parameters type Config struct { - TokenType string // token type - AllowGetAccessRequest bool // to allow GET requests for the token - AllowedResponseTypes []oauth2.ResponseType // allow the authorization type - AllowedGrantTypes []oauth2.GrantType // allow the grant type + TokenType string // token type + AllowGetAccessRequest bool // to allow GET requests for the token + AllowedResponseTypes []oauth2.ResponseType // allow the authorization type + AllowedGrantTypes []oauth2.GrantType // allow the grant type + AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod + ForcePKCE bool } // NewConfig create to configuration instance @@ -26,17 +28,23 @@ func NewConfig() *Config { oauth2.ClientCredentials, oauth2.Refreshing, }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ + oauth2.CodeChallengePlain, + oauth2.CodeChallengeS256, + }, } } // AuthorizeRequest authorization request type AuthorizeRequest struct { - ResponseType oauth2.ResponseType - ClientID string - Scope string - RedirectURI string - State string - UserID string - AccessTokenExp time.Duration - Request *http.Request + ResponseType oauth2.ResponseType + ClientID string + Scope string + RedirectURI string + State string + UserID string + CodeChallenge string + CodeChallengeMethod oauth2.CodeChallengeMethod + AccessTokenExp time.Duration + Request *http.Request } diff --git a/server/server.go b/server/server.go index ca1cd94..380f956 100755 --- a/server/server.go +++ b/server/server.go @@ -138,6 +138,16 @@ func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { return false } +// CheckCodeChallengeMethod checks for allowed code challenge method +func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { + for _, c := range s.Config.AllowedCodeChallengeMethods { + if c == ccm { + return true + } + } + return false +} + // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") @@ -154,13 +164,34 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, return nil, errors.ErrUnauthorizedClient } + cc := r.FormValue("code_challenge") + if cc == "" && s.Config.ForcePKCE { + return nil, errors.ErrCodeChallengeRquired + } + if cc != "" { + if len(cc) < 43 || len(cc) > 128 { + return nil, errors.ErrInvalidCodeChallengeLen + } + } + + ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) + // set default + if ccm == "" { + ccm = oauth2.CodeChallengePlain + } + if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { + return nil, errors.ErrUnsupportedCodeChallengeMethod + } + req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + CodeChallenge: cc, + CodeChallengeMethod: ccm, } return req, nil } @@ -193,12 +224,14 @@ func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) ( } tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, } return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } @@ -279,6 +312,13 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau return "", nil, errors.ErrUnsupportedGrantType } + codeVer := r.FormValue("code_verifier") + if s.Config.ForcePKCE { + if codeVer == "" { + return "", nil, errors.ErrInvalidRequest + } + } + clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err @@ -298,6 +338,7 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } + tgr.CodeVerifier = codeVer case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") @@ -354,7 +395,8 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { - case errors.ErrInvalidAuthorizeCode: + case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, + errors.ErrMissingCodeChallenge, errors.ErrMissingCodeChallenge: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient diff --git a/server/server_test.go b/server/server_test.go index f1c8ed1..d438996 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -23,6 +23,11 @@ var ( csrv *httptest.Server clientID = "111111" clientSecret = "11111111" + + plainChallenge = "plaintest" + s256Challenge = "s256test" + // echo s256test | sha256 | base64 | tr '/+' '_-' + s256ChallengeHash = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=" ) func init() { @@ -105,6 +110,111 @@ func TestAuthorizeCode(t *testing.T) { Expect().Status(http.StatusOK) } +func TestAuthorizeCodeWithChallengePlain(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + resObj := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", clientID). + WithFormField("code", code). + WithBasicAuth("code_verifier", "testchallenge"). + Expect(). + Status(http.StatusOK). + JSON().Object() + + t.Logf("%#v\n", resObj.Raw()) + + validationAccessToken(t, resObj.Value("access_token").String().Raw()) + } + })) + defer csrv.Close() + + manager.MapClientStorage(clientStore(csrv.URL)) + srv = server.NewDefaultServer(manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "000000" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", clientID). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("code_challenge", plainChallenge). + Expect().Status(http.StatusOK) +} + +func TestAuthorizeCodeWithChallengeS256(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + resObj := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", clientID). + WithFormField("code", code). + WithBasicAuth("code_verifier", s256Challenge). + Expect(). + Status(http.StatusOK). + JSON().Object() + + t.Logf("%#v\n", resObj.Raw()) + + validationAccessToken(t, resObj.Value("access_token").String().Raw()) + } + })) + defer csrv.Close() + + manager.MapClientStorage(clientStore(csrv.URL)) + srv = server.NewDefaultServer(manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "000000" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", clientID). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + WithQuery("code_challenge", s256ChallengeHash). + WithQuery("code_challenge_method", "S256"). + Expect().Status(http.StatusOK) +} + func TestImplicit(t *testing.T) { tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testServer(t, w, r) From e543c8eb4534b165a9fbbee647b7f55576f1f721 Mon Sep 17 00:00:00 2001 From: Nikolas Sepos Date: Mon, 16 Nov 2020 11:12:24 +0200 Subject: [PATCH 2/3] fix minor issues --- manage/manager.go | 2 +- server/server.go | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/manage/manager.go b/manage/manager.go index 1d48f77..74adf0e 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -265,7 +265,7 @@ func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error { return errors.ErrMissingCodeVerifier } if ver == "" { - return errors.New("missing code verifier") + return errors.ErrMissingCodeVerifier } ccm := ti.GetCodeChallengeMethod() if ccm.String() == "" { diff --git a/server/server.go b/server/server.go index 380f956..cbd9c84 100755 --- a/server/server.go +++ b/server/server.go @@ -168,10 +168,8 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, if cc == "" && s.Config.ForcePKCE { return nil, errors.ErrCodeChallengeRquired } - if cc != "" { - if len(cc) < 43 || len(cc) > 128 { - return nil, errors.ErrInvalidCodeChallengeLen - } + if cc != "" && (len(cc) < 43 || len(cc) > 128) { + return nil, errors.ErrInvalidCodeChallengeLen } ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) @@ -313,10 +311,8 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau } codeVer := r.FormValue("code_verifier") - if s.Config.ForcePKCE { - if codeVer == "" { - return "", nil, errors.ErrInvalidRequest - } + if s.Config.ForcePKCE && codeVer == "" { + return "", nil, errors.ErrInvalidRequest } clientID, clientSecret, err := s.ClientInfoHandler(r) From 7b9faad50f52e027875e94526fef55473d6c6cbe Mon Sep 17 00:00:00 2001 From: Nikolas Sepos Date: Sun, 20 Dec 2020 12:31:21 +0200 Subject: [PATCH 3/3] better challenge validation - removing padding before comparing base64(sha256) results - plaintest string to 43 chars --- const.go | 6 +++++- const_test.go | 7 +++++++ server/server_test.go | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/const.go b/const.go index 3309bc4..193e839 100644 --- a/const.go +++ b/const.go @@ -3,6 +3,7 @@ package oauth2 import ( "crypto/sha256" "encoding/base64" + "strings" ) // ResponseType the type of authorization request @@ -65,7 +66,10 @@ func (ccm CodeChallengeMethod) Validate(cc, ver string) bool { return cc == ver case CodeChallengeS256: s256 := sha256.Sum256([]byte(ver)) - return base64.URLEncoding.EncodeToString(s256[:]) == cc + // trim padding + a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=") + b := strings.TrimRight(cc, "=") + return a == b default: return false } diff --git a/const_test.go b/const_test.go index 7dbdcea..5037b9a 100644 --- a/const_test.go +++ b/const_test.go @@ -19,3 +19,10 @@ func TestValidateS256(t *testing.T) { t.Fatal("not valid") } } + +func TestValidateS256NoPadding(t *testing.T) { + cc := oauth2.CodeChallengeS256 + if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o", "s256test") { + t.Fatal("not valid") + } +} diff --git a/server/server_test.go b/server/server_test.go index d438996..b365f18 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -24,7 +24,7 @@ var ( clientID = "111111" clientSecret = "11111111" - plainChallenge = "plaintest" + plainChallenge = "ThisIsAFourtyThreeCharactersLongStringThing" s256Challenge = "s256test" // echo s256test | sha256 | base64 | tr '/+' '_-' s256ChallengeHash = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o="