Skip to content

Commit

Permalink
[#7272] Add custom bearer token feature and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Cruz authored and brondsem committed May 23, 2024
1 parent 5630217 commit 8a69cd0
Show file tree
Hide file tree
Showing 13 changed files with 358 additions and 236 deletions.
78 changes: 75 additions & 3 deletions Allura/allura/controllers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import logging
import json
import os
from base64 import b32encode
from datetime import datetime
Expand Down Expand Up @@ -63,6 +64,8 @@
from allura.controllers import BaseController
from allura.tasks.mail_tasks import send_system_mail_to_user
import six
import oauthlib.oauth2


log = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,6 +119,9 @@ def __init__(self):
self.subscriptions = SubscriptionsController()
self.oauth = OAuthController()

if asbool(config.get('auth.oauth2.enabled', False)):
self.oauth2 = OAuth2AuthorizationController()

if asbool(config.get('auth.allow_user_to_disable_account', False)):
self.disable = DisableAccountController()

Expand Down Expand Up @@ -1346,12 +1352,12 @@ def index(self, **kw):
access_tokens.append({
'type': 2,
'app': M.OAuth2ClientApp.query.get(client_id=oauth2_tok.client_id),
# TODO personal bearer tokens:
'is_bearer': False,
'api_key': None,
'is_bearer': oauth2_tok.is_bearer,
'api_key': oauth2_tok.access_token if oauth2_tok.is_bearer else None,
'last_access': oauth2_tok.last_access,
'_id': oauth2_tok._id,
})

# include auth codes too, but only if they're not already listed via an access/refresh token
for oauth2_auth in M.OAuth2AuthorizationCode.query.find({'user_id': c.user._id}):
client_app = M.OAuth2ClientApp.query.get(client_id=oauth2_auth.client_id)
Expand Down Expand Up @@ -1489,6 +1495,22 @@ def _check_revoke_perm(self, access_token):
flash('Invalid token ID', 'error')
redirect('.')

@expose()
@require_post()
def generate_bearer_token(self, client_id):
"""
Manually generates an OAuth2 access token without needing to go through the OAuth2 flow.
"""
M.OAuth2AccessToken(
client_id=client_id,
user_id=c.user._id,
access_token=h.nonce(40),
is_bearer=True,
expires_at=datetime.max
)

redirect('.')

@expose()
@require_post()
def revoke_access_token(self, _id):
Expand Down Expand Up @@ -1526,6 +1548,56 @@ def revoke_access_token2authcode(self, _id):
flash('Authorization revoked')
redirect('.')

class OAuth2AuthorizationController(BaseController):
def _check_security(self):
require_authenticated()

@property
def server(self):
from allura.controllers.rest import Oauth2Validator
return oauthlib.oauth2.WebApplicationServer(Oauth2Validator())

@expose('jinja:allura:templates/oauth2_authorize.html')
@without_trailing_slash
def authorize(self, **kwargs):
json_body = None
if request.body:
# We need to decode the request body and convert it to a dict because Turbogears creates it as bytes
# and oauthlib will treat it as x-www-form-urlencoded format.
decoded_body = str(request.body, 'utf-8')
json_body = json.loads(decoded_body)

scopes, credentials = self.server.validate_authorization_request(uri=request.url, http_method=request.method, headers=request.headers, body=json_body)

client_id = request.params.get('client_id')
client = M.OAuth2ClientApp.query.get(client_id=client_id)

# The credentials object has a request object that it's too big to be serialized,
# so we remove it because we don't need it for the rest of the authorization workflow
del credentials['request']

return dict(client=client, credentials=json.dumps(credentials))

@expose()
@require_post()
def do_authorize(self, yes=None, no=None):
client_id = request.params['client_id']
client = M.OAuth2ClientApp.query.get(client_id=client_id)

if no:
flash(f'{client.name} NOT AUTHORIZED', 'error')
redirect('/auth/oauth/')

credentials = json.loads(request.params['credentials'])
headers, body, status = self.server.create_authorization_response(
uri=request.url, http_method=request.method, body=request.body, headers=request.headers, scopes=[], credentials=credentials
)

response.status_int = status
response.headers.update(headers)
return body


class DisableAccountController(BaseController):

def _check_security(self):
Expand Down
57 changes: 7 additions & 50 deletions Allura/allura/controllers/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def dummy_access_token(self) -> str:


class Oauth2Validator(oauthlib.oauth2.RequestValidator):
def validate_client_id(self, client_id: str, request: oauthlib.common.Request) -> bool:
def validate_client_id(self, client_id: str, request: oauthlib.common.Request, *args, **kwargs) -> bool:
return M.OAuth2ClientApp.query.get(client_id=client_id) is not None

def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs):
Expand All @@ -270,7 +270,7 @@ def validate_scopes(self, client_id: str, scopes, client: oauthlib.oauth2.Client
return True

def validate_grant_type(self, client_id: str, grant_type: str, client: oauthlib.oauth2.Client, request: oauthlib.common.Request, *args, **kwargs) -> bool:
return grant_type in ['authorization_code', 'refresh_token', 'client_credentials']
return grant_type in ['authorization_code', 'refresh_token']

def get_default_scopes(self, client_id: str, request: oauthlib.common.Request, *args, **kwargs):
return []
Expand Down Expand Up @@ -313,7 +313,7 @@ def validate_bearer_token(self, token: str, scopes: list[str], request: oauthlib
return False

def validate_refresh_token(self, refresh_token: str, client: oauthlib.oauth2.Client, request: oauthlib.common.Request, *args, **kwargs) -> bool:
return M.OAuth2AccessToken.query.get(refresh_token=refresh_token) is not None
return M.OAuth2AccessToken.query.get(refresh_token=refresh_token, client_id=client.client_id) is not None

def confirm_redirect_uri(self, client_id: str, code: str, redirect_uri: str, client: oauthlib.oauth2.Client, request: oauthlib.common.Request, *args, **kwargs) -> bool:
# This method is called when the client is exchanging the authorization code for an access token.
Expand All @@ -322,11 +322,11 @@ def confirm_redirect_uri(self, client_id: str, code: str, redirect_uri: str, cli
return authorization.redirect_uri == redirect_uri

def save_authorization_code(self, client_id: str, code, request: oauthlib.common.Request, *args, **kwargs) -> None:
authorization = M.OAuth2AuthorizationCode.query.get(client_id=client_id, user_id=c.user._id, authorization_code=code['code'])
authorization = M.OAuth2AuthorizationCode.query.get(client_id=client_id, user_id=c.user._id)

# Remove the existing authorization code if it exists and create a new record
if authorization:
M.OAuth2AuthorizationCode.query.remove({'client_id': client_id, 'user_id': c.user._id, 'authorization_code': code['code']})
M.OAuth2AuthorizationCode.query.remove({'client_id': client_id, 'user_id': c.user._id})

log.info('Saving authorization code for client: %s', client_id)
auth_code = M.OAuth2AuthorizationCode(
Expand All @@ -347,10 +347,10 @@ def save_bearer_token(self, token, request: oauthlib.common.Request, *args, **kw
elif request.grant_type == 'refresh_token':
user_id = M.OAuth2AccessToken.query.get(client_id=request.client_id, refresh_token=request.refresh_token).user_id

current_token = M.OAuth2AccessToken.query.get(client_id=request.client_id, user_id=user_id)
current_token = M.OAuth2AccessToken.query.get(client_id=request.client_id, user_id=user_id, is_bearer=False)

if current_token:
M.OAuth2AccessToken.query.remove({'client_id': request.client_id, 'user_id': user_id})
M.OAuth2AccessToken.query.remove({'client_id': request.client_id, 'user_id': user_id, 'is_bearer': False})

bearer_token = M.OAuth2AccessToken(
client_id=request.client_id,
Expand Down Expand Up @@ -520,49 +520,6 @@ def _authenticate(self):
token.last_access = datetime.utcnow()
return token

@expose('jinja:allura:templates/oauth2_authorize.html')
@without_trailing_slash
def authorize(self, **kwargs):
security.require_authenticated()
json_body = None
if request.body:
# We need to decode the request body and convert it to a dict because Turbogears creates it as bytes
# and oauthlib will treat it as x-www-form-urlencoded format.
decoded_body = str(request.body, 'utf-8')
json_body = json.loads(decoded_body)

scopes, credentials = self.server.validate_authorization_request(uri=request.url, http_method=request.method, headers=request.headers, body=json_body)

client_id = request.params.get('client_id')
client = M.OAuth2ClientApp.query.get(client_id=client_id)

# The credentials object has a request object that it's too big to be serialized,
# so we remove it because we don't need it for the rest of the authorization workflow
del credentials['request']

return dict(client=client, credentials=json.dumps(credentials))

@expose('jinja:allura:templates/oauth2_authorize_ok.html')
@require_post()
def do_authorize(self, yes=None, no=None):
security.require_authenticated()

client_id = request.params['client_id']
client = M.OAuth2ClientApp.query.get(client_id=client_id)

if no:
flash(f'{client.name} NOT AUTHORIZED', 'error')
redirect('/auth/oauth/')

credentials = json.loads(request.params['credentials'])
headers, body, status = self.server.create_authorization_response(
uri=request.url, http_method=request.method, body=request.body, headers=request.headers, scopes=[], credentials=credentials
)

response.status_int = status
response.headers.update(headers)
return body

@expose('json:')
@require_post()
def token(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion Allura/allura/lib/custom_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def __call__(self, environ, start_response):
srcs += ' ' + ' '.join(environ['csp_form_actions'])

oauth_endpoints = (
'/rest/oauth2/authorize', '/rest/oauth2/do_authorize', '/rest/oauth/authorize', '/rest/oauth/do_authorize')
'/auth/oauth2/authorize', '/auth/oauth2/do_authorize', '/rest/oauth/authorize', '/rest/oauth/do_authorize')
if not req.path.startswith(oauth_endpoints): # Do not enforce CSP for OAuth1 and OAuth2 authorization
if asbool(self.config.get('csp.form_actions_enforce', False)):
rules.add(f"form-action {srcs}")
Expand Down
19 changes: 16 additions & 3 deletions Allura/allura/model/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ class OAuth2ClientApp(MappedClass):
class __mongometa__:
session = main_orm_session
name = 'oauth2_client_app'
unique_indexes = [('client_id', 'user_id')]
unique_indexes = [
('client_id', 'user_id'),
('client_id')
]
indexes = [
('user_id'),
]
Expand Down Expand Up @@ -194,7 +197,10 @@ class OAuth2AuthorizationCode(MappedClass):
class __mongometa__:
session = main_orm_session
name = 'oauth2_authorization_code'
unique_indexes = [('authorization_code', 'client_id', 'user_id')]
unique_indexes = [
('authorization_code', 'client_id', 'user_id'),
('authorization_code')
]
indexes = [
('user_id'),
]
Expand All @@ -219,7 +225,13 @@ class OAuth2AccessToken(MappedClass):
class __mongometa__:
session = main_orm_session
name = 'oauth2_access_token'
unique_indexes = [('access_token', 'client_id', 'user_id')]
unique_indexes = [
('access_token', 'client_id', 'user_id'),
('access_token')
]
custom_indexes = [
dict(fields=('refresh_token',), partialFilterExpression={'refresh_token': {'$gt': None}}, unique=True),
]
indexes = [
('user_id'),
]
Expand All @@ -232,6 +244,7 @@ class __mongometa__:
scopes = FieldProperty([str])
access_token = FieldProperty(str)
refresh_token = FieldProperty(str)
is_bearer = FieldProperty(bool, if_missing=False)
expires_at = FieldProperty(S.DateTime)
last_access = FieldProperty(datetime)

Expand Down
Loading

0 comments on commit 8a69cd0

Please sign in to comment.