diff --git a/server/handler.go b/server/handler.go index e0d5d32..a5fd139 100755 --- a/server/handler.go +++ b/server/handler.go @@ -27,6 +27,9 @@ type ( // RefreshingScopeHandler check the scope of the refreshing token RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) + //RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other + RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error) + // ResponseErrorHandler response error handing ResponseErrorHandler func(re *errors.Response) diff --git a/server/server.go b/server/server.go index ca1cd94..9aec8f2 100755 --- a/server/server.go +++ b/server/server.go @@ -47,6 +47,7 @@ type Server struct { ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingValidationHandler RefreshingValidationHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler @@ -392,6 +393,22 @@ func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *o } } + if validationFn := s.RefreshingValidationHandler; validationFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + allowed, err := validationFn(rti) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { diff --git a/server/server_config.go b/server/server_config.go index d9b740d..e0d7d1d 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -54,6 +54,12 @@ func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { s.RefreshingScopeHandler = handler } +// SetRefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other +func (s *Server) SetRefreshingValidationHandler(handler RefreshingValidationHandler) { + s.RefreshingValidationHandler = handler +} + + // SetResponseErrorHandler response error handling func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) { s.ResponseErrorHandler = handler