diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 788fe8c5..066b33ef 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -303,6 +303,7 @@ class CCAccessTokenRequest(Message): "client_secret": SINGLE_OPTIONAL_STRING, "grant_type": SINGLE_REQUIRED_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "resource": OPTIONAL_LIST_OF_STRINGS, } def verify(self, **kwargs): diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index 7752e531..5dfb5991 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -76,9 +76,9 @@ def __call__( else: setattr(grant, key, val) - if resources is None: + if resources is None and (grant.resources is None or len(grant.resources) == 0): grant.resources = [_client_id] - else: + elif resources is not None: grant.resources = resources # Scope handling. If allowed scopes are defined for the client filter using that diff --git a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py index fa3db7cd..c23e327c 100755 --- a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py +++ b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py @@ -2,12 +2,16 @@ from typing import Optional from typing import Union +from idpyoidc.exception import ImproperlyConfigured from idpyoidc.message import Message +from idpyoidc.message.oauth2 import TokenErrorResponse from idpyoidc.message.oauth2 import CCAccessTokenRequest from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.util import importer from idpyoidc.util import sanitize from . import TokenEndpointHelper +from . import validate_resource_indicators_policy logger = logging.getLogger(__name__) @@ -22,7 +26,6 @@ def process_request(self, req: Union[Message, dict], **kwargs): logger.debug("Client credentials flow") # verify the client and the user - client_id = req["client_id"] _authenticated = req.get("authenticated", False) if not _authenticated: @@ -45,11 +48,33 @@ def process_request(self, req: Union[Message, dict], **kwargs): branch_id = _mngr.add_grant(["client_credentials", client_id]) _session_info = _mngr.get_session_info(branch_id) + _cinfo = _context.cdb.get(client_id) + + if "resource_indicators" in _cinfo and "client_credentials" in _cinfo["resource_indicators"]: + resource_indicators_config = _cinfo["resource_indicators"]["client_credentials"] + else: + resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None) + + if resource_indicators_config is not None: + if "policy" not in resource_indicators_config: + policy = {"policy": {"function": validate_resource_indicators_policy}} + resource_indicators_config.update(policy) + + req = self._enforce_resource_indicators_policy(req, resource_indicators_config) + + if isinstance(req, TokenErrorResponse): + return req + _grant = _session_info["grant"] token_type = "Bearer" _allowed = _context.cdb[client_id].get("allowed_scopes", []) + resources = req.get("resource", None) + if resources: + token_args = {"resources": resources} + else: + token_args = None access_token = self._mint_token( token_class="access_token", grant=_grant, @@ -58,6 +83,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): based_on=None, scope=_allowed, token_type=token_type, + token_args=token_args, ) _resp = { @@ -77,3 +103,24 @@ def post_parse_request( request = CCAccessTokenRequest(**request.to_dict()) logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) return request + + def _enforce_resource_indicators_policy(self, request, config): + _context = self.endpoint.upstream_get("context") + + policy = config["policy"] + function = policy["function"] + kwargs = policy.get("kwargs", {}) + + if isinstance(function, str): + try: + fn = importer(function) + except Exception: + raise ImproperlyConfigured(f"Error importing {function} policy function") + else: + fn = function + try: + return fn(request, context=_context, **kwargs) + except Exception as e: + logger.error(f"Error while executing the {fn} policy function: {e}") + return self.error_cls(error="server_error", error_description="Internal server error") +