Skip to content

Commit

Permalink
Merge pull request #163 from thegrumpylion/pkce
Browse files Browse the repository at this point in the history
implement PKCE for AuthorizationCode grant
  • Loading branch information
LyricTian authored Dec 24, 2020
2 parents 07c72de + 7b9faad commit f3419dd
Show file tree
Hide file tree
Showing 12 changed files with 382 additions and 79 deletions.
40 changes: 40 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package oauth2

import (
"crypto/sha256"
"encoding/base64"
"strings"
)

// ResponseType the type of authorization request
type ResponseType string

Expand Down Expand Up @@ -34,3 +40,37 @@ func (gt GrantType) String() string {
}
return ""
}

// CodeChallengeMethod PCKE method
type CodeChallengeMethod string

const (
// CodeChallengePlain PCKE Method
CodeChallengePlain CodeChallengeMethod = "plain"
// CodeChallengeS256 PCKE Method
CodeChallengeS256 CodeChallengeMethod = "S256"
)

func (ccm CodeChallengeMethod) String() string {
if ccm == CodeChallengePlain ||
ccm == CodeChallengeS256 {
return string(ccm)
}
return ""
}

// Validate code challenge
func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
switch ccm {
case CodeChallengePlain:
return cc == ver
case CodeChallengeS256:
s256 := sha256.Sum256([]byte(ver))
// trim padding
a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=")
b := strings.TrimRight(cc, "=")
return a == b
default:
return false
}
}
28 changes: 28 additions & 0 deletions const_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package oauth2_test

import (
"testing"

"github.com/go-oauth2/oauth2/v4"
)

func TestValidatePlain(t *testing.T) {
cc := oauth2.CodeChallengePlain
if !cc.Validate("plaintest", "plaintest") {
t.Fatal("not valid")
}
}

func TestValidateS256(t *testing.T) {
cc := oauth2.CodeChallengeS256
if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") {
t.Fatal("not valid")
}
}

func TestValidateS256NoPadding(t *testing.T) {
cc := oauth2.CodeChallengeS256
if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o", "s256test") {
t.Fatal("not valid")
}
}
3 changes: 3 additions & 0 deletions errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ var (
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
ErrMissingCodeVerifier = errors.New("missing code verifier")
ErrMissingCodeChallenge = errors.New("missing code challenge")
ErrInvalidCodeChallenge = errors.New("invalid code challenge")
)
69 changes: 39 additions & 30 deletions errors/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,51 @@ func (r *Response) SetHeader(key, value string) {

// https://tools.ietf.org/html/rfc6749#section-5.2
var (
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrCodeChallengeRquired = errors.New("invalid_request")
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
)

// Descriptions error description
var Descriptions = map[error]string{
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
}

// StatusCodes response error HTTP status code
var StatusCodes = map[error]int{
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrCodeChallengeRquired: 400,
ErrUnsupportedCodeChallengeMethod: 400,
ErrInvalidCodeChallengeLen: 400,
}
13 changes: 11 additions & 2 deletions example/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -33,7 +35,9 @@ var (

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
u := config.AuthCodeURL("xyz")
u := config.AuthCodeURL("xyz",
oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")),
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
http.Redirect(w, r, u, http.StatusFound)
})

Expand All @@ -49,7 +53,7 @@ func main() {
http.Error(w, "Code not found", http.StatusBadRequest)
return
}
token, err := config.Exchange(context.Background(), code)
token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -130,3 +134,8 @@ func main() {
log.Println("Client is running at 9094 port.Please open http://localhost:9094")
log.Fatal(http.ListenAndServe(":9094", nil))
}

func genCodeChallengeS256(s string) string {
s256 := sha256.Sum256([]byte(s))
return base64.URLEncoding.EncodeToString(s256[:])
}
21 changes: 12 additions & 9 deletions manage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ import (

// TokenGenerateRequest provide to generate the token request parameters
type TokenGenerateRequest struct {
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
Refresh string
AccessTokenExp time.Duration
Request *http.Request
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
CodeChallenge string
CodeChallengeMethod CodeChallengeMethod
Refresh string
CodeVerifier string
AccessTokenExp time.Duration
Request *http.Request
}

// Manager authorization management interface
Expand Down
29 changes: 29 additions & 0 deletions manage/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType,
if exp := tgr.AccessTokenExp; exp > 0 {
ti.SetAccessExpiresIn(exp)
}
if tgr.CodeChallenge != "" {
ti.SetCodeChallenge(tgr.CodeChallenge)
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
}

tv, err := m.authorizeGenerate.Token(ctx, td)
if err != nil {
Expand Down Expand Up @@ -251,6 +255,28 @@ func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.To
return ti, nil
}

func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
cc := ti.GetCodeChallenge()
// early return
if cc == "" && ver == "" {
return nil
}
if cc == "" {
return errors.ErrMissingCodeVerifier
}
if ver == "" {
return errors.ErrMissingCodeVerifier
}
ccm := ti.GetCodeChallengeMethod()
if ccm.String() == "" {
ccm = oauth2.CodeChallengePlain
}
if !ccm.Validate(cc, ver) {
return errors.ErrInvalidCodeChallenge
}
return nil
}

// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
Expand All @@ -275,6 +301,9 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType,
if err != nil {
return nil, err
}
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
return nil, err
}
tgr.UserID = ti.GetUserID()
tgr.Scope = ti.GetScope()
if exp := ti.GetAccessExpiresIn(); exp > 0 {
Expand Down
4 changes: 4 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type (
SetCodeCreateAt(time.Time)
GetCodeExpiresIn() time.Duration
SetCodeExpiresIn(time.Duration)
GetCodeChallenge() string
SetCodeChallenge(string)
GetCodeChallengeMethod() CodeChallengeMethod
SetCodeChallengeMethod(CodeChallengeMethod)

GetAccess() string
SetAccess(string)
Expand Down
48 changes: 35 additions & 13 deletions models/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@ func NewToken() *Token {

// Token token model
type Token struct {
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeChallenge string `bson:"CodeChallenge"`
CodeChallengeMethod string `bson:"CodeChallengeMethod"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}

// New create to token model instance
Expand Down Expand Up @@ -103,6 +105,26 @@ func (t *Token) SetCodeExpiresIn(exp time.Duration) {
t.CodeExpiresIn = exp
}

// GetCodeChallenge challenge code
func (t *Token) GetCodeChallenge() string {
return t.CodeChallenge
}

// SetCodeChallenge challenge code
func (t *Token) SetCodeChallenge(code string) {
t.CodeChallenge = code
}

// GetCodeChallengeMethod challenge method
func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
}

// SetCodeChallengeMethod challenge method
func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
t.CodeChallengeMethod = string(method)
}

// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
Expand Down
Loading

0 comments on commit f3419dd

Please sign in to comment.