Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for resource indicator #102

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/idpyoidc/message/oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class CCAccessTokenRequest(Message):
"client_secret": SINGLE_OPTIONAL_STRING,
"grant_type": SINGLE_REQUIRED_STRING,
"scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS,
"resource": OPTIONAL_LIST_OF_STRINGS,
}

def verify(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/idpyoidc/server/authz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def __call__(
else:
setattr(grant, key, val)

if resources is None:
if resources is None and (grant.resources is None or len(grant.resources) == 0):
grant.resources = [_client_id]
else:
elif resources is not None:
grant.resources = resources

# Scope handling. If allowed scopes are defined for the client filter using that
Expand Down
49 changes: 48 additions & 1 deletion src/idpyoidc/server/oauth2/token_helper/client_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from typing import Optional
from typing import Union

from idpyoidc.exception import ImproperlyConfigured
from idpyoidc.message import Message
from idpyoidc.message.oauth2 import TokenErrorResponse
from idpyoidc.message.oauth2 import CCAccessTokenRequest
from idpyoidc.time_util import utc_time_sans_frac
from idpyoidc.util import importer
from idpyoidc.util import sanitize

from . import TokenEndpointHelper
from . import validate_resource_indicators_policy

logger = logging.getLogger(__name__)

Expand All @@ -22,7 +26,6 @@ def process_request(self, req: Union[Message, dict], **kwargs):
logger.debug("Client credentials flow")

# verify the client and the user

client_id = req["client_id"]
_authenticated = req.get("authenticated", False)
if not _authenticated:
Expand All @@ -45,11 +48,33 @@ def process_request(self, req: Union[Message, dict], **kwargs):
branch_id = _mngr.add_grant(["client_credentials", client_id])
_session_info = _mngr.get_session_info(branch_id)

_cinfo = _context.cdb.get(client_id)

if "resource_indicators" in _cinfo and "client_credentials" in _cinfo["resource_indicators"]:
resource_indicators_config = _cinfo["resource_indicators"]["client_credentials"]
else:
resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None)

if resource_indicators_config is not None:
if "policy" not in resource_indicators_config:
policy = {"policy": {"function": validate_resource_indicators_policy}}
resource_indicators_config.update(policy)

req = self._enforce_resource_indicators_policy(req, resource_indicators_config)

if isinstance(req, TokenErrorResponse):
return req

_grant = _session_info["grant"]

token_type = "Bearer"

_allowed = _context.cdb[client_id].get("allowed_scopes", [])
resources = req.get("resource", None)
if resources:
token_args = {"resources": resources}
else:
token_args = None
access_token = self._mint_token(
token_class="access_token",
grant=_grant,
Expand All @@ -58,6 +83,7 @@ def process_request(self, req: Union[Message, dict], **kwargs):
based_on=None,
scope=_allowed,
token_type=token_type,
token_args=token_args,
)

_resp = {
Expand All @@ -77,3 +103,24 @@ def post_parse_request(
request = CCAccessTokenRequest(**request.to_dict())
logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
return request

def _enforce_resource_indicators_policy(self, request, config):
_context = self.endpoint.upstream_get("context")

policy = config["policy"]
function = policy["function"]
kwargs = policy.get("kwargs", {})

if isinstance(function, str):
try:
fn = importer(function)
except Exception:
raise ImproperlyConfigured(f"Error importing {function} policy function")
else:
fn = function
try:
return fn(request, context=_context, **kwargs)
except Exception as e:
logger.error(f"Error while executing the {fn} policy function: {e}")
return self.error_cls(error="server_error", error_description="Internal server error")

Loading