diff --git a/example/flask_op/private/cookie_jwks.json b/example/flask_op/private/cookie_jwks.json index 528741ff..ab562780 100644 --- a/example/flask_op/private/cookie_jwks.json +++ b/example/flask_op/private/cookie_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "enc", "k": "GCizp3ewVRV0VZEef3VQwFve7n2QwAFI"}, {"kty": "oct", "use": "sig", "kid": "sig", "k": "QC2JxpVJXPDMpYv_h76jIrt_lA1P4KSu"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "enc", "k": "kB4Z6SmpYe0URDyigVNHi25PeZ1MPG_B"}, {"kty": "oct", "use": "sig", "kid": "sig", "k": "1HMtbIxBQISnRkLSnyf_wJOJ0SzYp6pC"}]} \ No newline at end of file diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index a4ad740d..6c311539 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -8,6 +8,8 @@ from cryptojwt.jws.jws import factory from cryptojwt.key_bundle import key_by_alg +from idpyoidc.client.client_auth import BearerHeader +from idpyoidc.client.client_auth import find_token_info from idpyoidc.client.service_context import ServiceContext from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT @@ -159,7 +161,7 @@ def dpop_header( return headers -def add_support(services, dpop_signing_alg_values_supported): +def add_support(services, dpop_signing_alg_values_supported, with_dpop_header=None): """ Add the necessary pieces to make pushed authorization happen. @@ -181,7 +183,52 @@ def add_support(services, dpop_signing_alg_values_supported): _service.construct_extra_headers.append(dpop_header) - # The same for userinfo requests - _userinfo_service = services.get("userinfo") - if _userinfo_service: - _userinfo_service.construct_extra_headers.append(dpop_header) + # To be backward compatible + if with_dpop_header is None: + with_dpop_header = ["userinfo"] + + # Add dpop HTTP header to these + for _srv in with_dpop_header: + if _srv == "accesstoken": + continue + _service = services.get(_srv) + if _service: + _service.construct_extra_headers.append(dpop_header) + +class DPoPClientAuth(BearerHeader): + tag = "dpop_client_auth" + + def construct(self, request=None, service=None, http_args=None, **kwargs): + """ + Constructing the Authorization header. The value of + the Authorization header is "Bearer ". + + :param request: Request class instance + :param service: The service this authentication method applies to. + :param http_args: HTTP header arguments + :param kwargs: extra keyword arguments + :return: + """ + + _token_type = "access_token" + + _token_info = find_token_info(request, _token_type, service, **kwargs) + + if not _token_info: + raise KeyError("No bearer token available") + + # The authorization value starts with the token_type + # if _token_info["token_type"].to_lower() != "bearer": + _bearer = f"DPoP {_token_info[_token_type]}" + + # Add 'Authorization' to the headers + if http_args is None: + http_args = {"headers": {}} + http_args["headers"]["Authorization"] = _bearer + else: + try: + http_args["headers"]["Authorization"] = _bearer + except KeyError: + http_args["headers"] = {"Authorization": _bearer} + + return http_args diff --git a/src/idpyoidc/client/oauth2/add_on/par.py b/src/idpyoidc/client/oauth2/add_on/par.py index db358e8f..d2216780 100644 --- a/src/idpyoidc/client/oauth2/add_on/par.py +++ b/src/idpyoidc/client/oauth2/add_on/par.py @@ -1,19 +1,20 @@ import logging +from typing import Union from cryptojwt import JWT from cryptojwt.utils import importer +from requests import request from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.message import Message from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest from idpyoidc.server.util import execute from idpyoidc.util import instantiate -from requests import request logger = logging.getLogger(__name__) -def push_authorization(request_args, service, **kwargs): +def get_request_parameters(request_args: Union[dict, Message], service, **kwargs) -> dict: """ :param request_args: All the request arguments as a AuthorizationRequest instance :param service: The service to which this post construct method is applied. @@ -26,20 +27,21 @@ def push_authorization(request_args, service, **kwargs): logger.debug(f"PAR kwargs: {kwargs}") if method_args["apply"] is False: - return request_args + return {"request_args": request_args} _http_method = method_args["http_client"] _httpc_params = service.upstream_get("unit").httpc_params + logger.debug(f"httpc_params: {_httpc_params}") # Add client authentication if needed _headers = {} authn_method = method_args["authn_method"] if authn_method: + _name = "" if isinstance(authn_method, str): if authn_method not in _context.client_authn_methods: _context.client_authn_methods[authn_method] = CLIENT_AUTHN_METHOD[authn_method]() else: - _name = "" for _name, spec in authn_method.items(): if _name not in _context.client_authn_methods: _context.client_authn_methods[_name] = execute(spec) @@ -48,10 +50,10 @@ def push_authorization(request_args, service, **kwargs): _args = {} if _context.issuer: _args["iss"] = _context.issuer - if _name == "client_attestation": - _wia = kwargs.get("client_attestation") + if _name == 'client_authentication_attestation': + _wia = kwargs.get('wallet_instance_attestation') if _wia: - _args["client_attestation"] = _wia + _args["attestation"] = _wia _headers = service.get_headers( request_args, http_method=_http_method, authn_method=authn_method, **_args @@ -60,6 +62,8 @@ def push_authorization(request_args, service, **kwargs): # construct the message body if method_args["body_format"] == "urlencoded": + if isinstance(request_args, dict): + request_args = Message(**request_args) _body = request_args.to_urlencoded() else: _jwt = JWT( @@ -74,13 +78,29 @@ def push_authorization(request_args, service, **kwargs): _body = _msg.to_urlencoded() + return { + "http_method": _http_method, + "body": _body, + "headers": _headers, + "httpc_params": _httpc_params + } + + +def push_authorization(request_args, service, **kwargs): + _req_info = get_request_parameters(request_args=request_args, service=service, **kwargs) + if "request_args" in _req_info: + return _req_info["request_args"] + + _context = service.upstream_get("context") + # Send it to the Pushed Authorization Request Endpoint using POST - resp = _http_method( + kwargs = _req_info.get("httpc_params", {}) + resp = _req_info["http_method"]( method="POST", url=_context.provider_info["pushed_authorization_request_endpoint"], - data=_body, - headers=_headers, - **_httpc_params + data=_req_info["body"], + headers=_req_info["headers"], + **kwargs ) if resp.status_code == 200: @@ -99,12 +119,12 @@ def push_authorization(request_args, service, **kwargs): def add_support( - services, - body_format="jws", - signing_algorithm="RS256", - http_client=None, - merge_rule="strict", - authn_method="", + services, + body_format="jws", + signing_algorithm="RS256", + http_client=None, + merge_rule="strict", + authn_method="", ): """ Add the necessary pieces to support Pushed authorization. diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 9d85f1fd..05300d8a 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -30,8 +30,8 @@ class Authorization(Service): response_body_type = "urlencoded" _supports = { - "response_types_supported": ["code"], - "response_modes_supported": ["query", "fragment"], + "response_types": ["code"], + "response_modes": ["query", "fragment"], } _callback_path = { diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 254e1bd2..41421ea8 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -4,7 +4,6 @@ from idpyoidc.client.defaults import DEFAULT_RESPONSE_MODE from idpyoidc.client.service import Service -from idpyoidc.exception import MissingParameter from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.message import Message @@ -13,22 +12,14 @@ def get_state_parameter(request_args, kwargs): """Find a state value from a set of possible places.""" - try: - _state = kwargs["state"] - except KeyError: - try: - _state = request_args["state"] - except KeyError: - raise MissingParameter("state") - - return _state + return kwargs.get("state", request_args.get("state", None)) def pick_redirect_uri( - context, - request_args: Optional[Union[Message, dict]] = None, - response_type: Optional[str] = "", - response_mode: Optional[str] = "", + context, + request_args: Optional[Union[Message, dict]] = None, + response_type: Optional[str] = "", + response_mode: Optional[str] = "", ): if request_args is None: request_args = {} @@ -87,7 +78,7 @@ def pick_redirect_uri( def pre_construct_pick_redirect_uri( - request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, **kwargs + request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, **kwargs ): request_args["redirect_uri"] = pick_redirect_uri( service.upstream_get("context"), request_args=request_args @@ -97,5 +88,9 @@ def pre_construct_pick_redirect_uri( def set_state_parameter(request_args=None, **kwargs): """Assigned a state value.""" - request_args["state"] = get_state_parameter(request_args, kwargs) - return request_args, {"state": request_args["state"]} + _state = get_state_parameter(request_args, kwargs) + if _state: + request_args["state"] = _state + return request_args, {"state": request_args["state"]} + else: + return request_args, {} diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 05fce76b..c2bc2ba8 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -110,7 +110,9 @@ def post_parse_response(self, response, **kwargs): return response def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, + response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 231d56b0..674de23a 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -88,6 +88,7 @@ def __init__( self.client_authn_methods = {} if conf: + LOGGER.debug(f"Service config: {conf}") self.conf = conf for param in [ "msg_type", @@ -110,8 +111,8 @@ def __init__( if _client_authn_methods: self.client_authn_methods = client_auth_setup(method_to_item(_client_authn_methods)) - if self.default_authn_method: - if self.default_authn_method not in self.client_authn_methods: + if not self.client_authn_methods: + if self.default_authn_method: self.client_authn_methods[self.default_authn_method] = single_authn_setup( self.default_authn_method, None ) diff --git a/src/idpyoidc/configure.py b/src/idpyoidc/configure.py index 5258b576..1418b65f 100644 --- a/src/idpyoidc/configure.py +++ b/src/idpyoidc/configure.py @@ -116,10 +116,17 @@ def __init__( add_base_path(conf, base_path, self._dir_attributes, "dir") # entity info - self.domain = domain or conf.get("domain", "127.0.0.1") - self.port = port or conf.get("port", 80) + if domain != "": + self.domain = domain + else: + self.domain = conf.get("domain", "127.0.0.1") + if port != 0: + self.port = port + else: + self.port = conf.get("port", 80) - self.conf = set_domain_and_port(conf, self.domain, self.port) + if self.domain: + self.conf = set_domain_and_port(conf, self.domain, self.port) def __getattr__(self, item, default=None): if item in self: diff --git a/src/idpyoidc/impexp.py b/src/idpyoidc/impexp.py index ff8938ea..4f1fe0c8 100644 --- a/src/idpyoidc/impexp.py +++ b/src/idpyoidc/impexp.py @@ -65,6 +65,8 @@ def dump_attr(self, cls, item, exclude_attributes: Optional[List[str]] = None) - val = qualified_name(item) elif isinstance(cls, list): val = [self.dump_attr(cls[0], v, exclude_attributes) for v in item] + elif isinstance(item, dict): + val = {k: self.dump_attr(type2cls(v), v, exclude_attributes) for k, v in item.items()} else: val = item.dump(exclude_attributes=exclude_attributes) diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 6e93aea0..0d9c0006 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -46,8 +46,16 @@ def __init__( entity_id: Optional[str] = "", key_conf: Optional[dict] = None, ): - self.entity_id = entity_id or conf.get("entity_id") - self.issuer = conf.get("issuer", self.entity_id) + # issuer == entity_id + # entity_id as parameter has higher precedence then in conf + _iss = entity_id or conf.get("entity_id", "") + if _iss: + self.entity_id = self.issuer = _iss + else: + _iss = conf.get("issuer", "") + if _iss: + self.entity_id = self.issuer = _iss + self.persistence = None if upstream_get is None: @@ -81,6 +89,7 @@ def __init__( cwd=cwd, cookie_handler=cookie_handler, keyjar=self.keyjar, + entity_id=self.entity_id ) # Need to have context in place before doing this diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index f250c31b..755a941c 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -360,7 +360,7 @@ def _verify( res = super()._verify( request=request, key_type="client_secret", endpoint=endpoint, **kwargs ) - # Verify that a HS alg was used + # Verify that an HS alg was used return res @@ -487,9 +487,12 @@ def verify_client( if not allowed_methods: allowed_methods = list(methods.keys()) # If not specific for this endpoint then all + logger.debug(f"Allowed client authentication methods: {allowed_methods}") _method = None + _cinfo = {} for _method in (methods[meth] for meth in allowed_methods): if not _method.is_usable(request=request, authorization_token=authorization_token): + logger.debug(f"{_method} not usable") continue try: logger.info(f"Verifying client authentication using {_method.tag}") diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 3f304e7c..bc32c942 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -54,7 +54,7 @@ "authorization_code": { "supports_minting": ["access_token", "refresh_token"], "max_usage": 1, - "expires_in": 120, # 2 minutes + "expires_in": 300, # 5 minutes }, "access_token": {"expires_in": 3600}, # An hour "refresh_token": { @@ -123,7 +123,7 @@ "id_token", ], "max_usage": 1, - "expires_in": 120, # 2 minutes + "expires_in": 300, # 5 minutes }, "access_token": {"expires_in": 3600}, # An hour "refresh_token": { @@ -256,6 +256,10 @@ def __init__( _val = conf.get(key) if not _val: if key in self.default_config: + if key == "issuer" and self.default_config[key] == 'https://{domain}:{port}': + self.issuer = "" + continue + _val = self.format( copy.deepcopy(self.default_config[key]), base_path=base_path, @@ -637,3 +641,26 @@ def __init__( "kwargs": {}, }, } + +DEFAULT_OAUTH2_ENDPOINTS = { + "server_metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "register": { + "path": "registration", + "class": "idpyoidc.server.oauth2.registration.Registration", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + } +} diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 256f0029..6bc94b17 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -226,6 +226,7 @@ def parse_request( else: req = self.request_cls() + LOGGER.debug(f"Parsed request: {req}") # Verify that the client is allowed to do this auth_info = self.client_authentication(req, http_info, endpoint=self, **kwargs) LOGGER.debug(f"parse_request:auth_info:{auth_info}") diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 3317a5fc..c25fc071 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -21,6 +21,7 @@ from idpyoidc.server.session.manager import SessionManager from idpyoidc.server.session.manager import create_session_manager from idpyoidc.server.template_handler import Jinja2TemplateHandler +from idpyoidc.server.user_authn.authn_context import AuthnBroker from idpyoidc.server.user_authn.authn_context import populate_authn_broker from idpyoidc.server.util import get_http_params from idpyoidc.util import importer @@ -116,7 +117,8 @@ def __init__( keyjar: Optional[KeyJar] = None, claims_class: Optional[Claims] = None, ): - _id = entity_id or conf.get("issuer", "") + _id = entity_id or conf.get("entity_id", conf.get("issuer", "")) + conf["issuer"] = _id OidcContext.__init__(self, conf, entity_id=_id) self.conf = conf self.upstream_get = upstream_get @@ -157,7 +159,7 @@ def __init__( self.endpoint_to_authn_method = {} self.httpc = httpc or request self.idtoken = None - self.issuer = "" + self.issuer = _id # self.jwks_uri = None self.login_hint_lookup = None self.login_hint2acrs = None @@ -174,7 +176,6 @@ def __init__( self.client_authn_method = {} for param in [ - "issuer", "sso_ttl", "symkey", "client_authn", @@ -443,7 +444,7 @@ def setup_authentication(self): _conf, self.upstream_get, self.template_handler ) else: - self.authn_broker = {} + self.authn_broker = AuthnBroker() self.endpoint_to_authn_method = {} for method in self.authn_broker: diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index a90cbcfd..20c9d609 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -387,13 +387,23 @@ def authentication_error_response(self, request, error, error_description, **kwa def verify_response_type(self, request: Union[Message, dict], cinfo: dict) -> bool: # Checking response types - _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types_supported", [])] + _rts = cinfo.get("response_types", cinfo.get("response_types_supported",[])) + _registered = [set(rt.split(" ")) for rt in _rts] if not _registered: # If no response_type is registered by the client then we'll use code. _registered = [{"code"}] + if isinstance(request["response_type"], list): + _asked_for = set(request["response_type"]) + else: + _asked_for = set(request["response_type"].split(" ")) + # Is the asked for response_type among those that are permitted - return set(request["response_type"]) in _registered + if _asked_for in _registered: + return True + else: + logger.debug(f"Asked for response_type: {_asked_for} not among registered: {_registered}") + return False def mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): usage_rules = grant.usage_rules.get(token_class, {}) @@ -663,6 +673,8 @@ def setup_auth( res = self.pick_authn_method(request, redirect_uri, acr, **kwargs) + logger.debug(f"pick_authn_method response: {res}") + authn = res["method"] authn_class_ref = res["acr"] @@ -867,7 +879,12 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict list(set(scope + resource_scopes)), _sinfo["client_id"] ) - rtype = set(request["response_type"][:]) + if isinstance(request["response_type"], str): + rtype = set(request["response_type"].split(" ")) + else: + rtype = set(request["response_type"][:]) + logger.debug(f"Response type: {rtype}") + handled_response_type = [] fragment_enc = True @@ -876,6 +893,15 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict grant = _sinfo["grant"] + _aud = request.get("audience", None) + if _aud: + if isinstance(_aud, list): + _aud_arg = {"aud": _aud} + else: + _aud_arg = {"aud": [_aud]} + else: + _aud_arg = {} + if "code" in rtype: _code = self.mint_token( token_class="authorization_code", @@ -892,6 +918,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict token_class="access_token", grant=grant, session_id=_sinfo["branch_id"], + **_aud_arg ) aresp["access_token"] = _access_token.value aresp["token_type"] = "Bearer" @@ -910,6 +937,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict elif {"id_token", "token"}.issubset(rtype): kwargs = {"access_token": _access_token.value} + kwargs.update(_aud_arg) if rtype == {"id_token"}: kwargs["as_if"] = "userinfo" @@ -944,6 +972,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict return {"response_args": resp, "fragment_enc": fragment_enc} aresp = self.extra_response_args(aresp) + logger.debug(f"Authn response: {aresp}") return {"response_args": aresp, "fragment_enc": fragment_enc} diff --git a/src/idpyoidc/server/oauth2/pushed_authorization.py b/src/idpyoidc/server/oauth2/pushed_authorization.py index 693b073f..4c8405f0 100644 --- a/src/idpyoidc/server/oauth2/pushed_authorization.py +++ b/src/idpyoidc/server/oauth2/pushed_authorization.py @@ -1,3 +1,4 @@ +import json import uuid from typing import Optional from typing import Union @@ -34,17 +35,26 @@ def process_request(self, request: Optional[Union[Message, str]] = None, **kwarg # create URN if isinstance(request, str): - _request = AuthorizationRequest().from_urlencoded(request) + _request = self.request_cls().from_urlencoded(request) else: - _request = AuthorizationRequest(**request) + _request = self.request_cls(**request) _request.verify(keyjar=self.upstream_get("attribute", "keyjar")) _urn = "urn:uuid:{}".format(uuid.uuid4()) # Store the parsed and verified request self.upstream_get("context").par_db[_urn] = _request + _response = {"request_uri": _urn, "expires_in": self.ttl} + # + _msg = AuthorizationRequest() + for param in _msg.required_parameters(): + _response[param] = _request.get(param) + + _redirect_uri = _request.get("redirect_uri") + if _redirect_uri: + _response["redirect_uri"] = _redirect_uri return { - "http_response": {"request_uri": _urn, "expires_in": self.ttl}, + "http_response": json.dumps(_response), "return_uri": _request["redirect_uri"], } diff --git a/src/idpyoidc/server/oauth2/token_helper/access_token.py b/src/idpyoidc/server/oauth2/token_helper/access_token.py index 3262cb71..afb74337 100755 --- a/src/idpyoidc/server/oauth2/token_helper/access_token.py +++ b/src/idpyoidc/server/oauth2/token_helper/access_token.py @@ -116,7 +116,11 @@ def process_request(self, req: Union[Message, dict], **kwargs): if resources: token_args = {"resources": resources} else: - token_args = None + token_args = {} + + _aud = grant.authorization_request.get("audience") + if _aud: + token_args["aud"] = _aud try: token = self._mint_token( diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index 0d1c85a2..286827b2 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -314,7 +314,11 @@ def mint_token( for param in ["audience", "aud"]: _val = class_args.get(param) if _val: - _aud = _aud.union(set(_val)) + if isinstance(_val, list): + _aud = _aud.union(set(_val)) + else: + _val = [_val] + _aud = _aud.union(set(_val)) del class_args[param] if _aud != set(): diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index 181a000c..7aa74752 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -6,6 +6,7 @@ from cryptojwt.jws.jws import factory from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT +from idpyoidc.message import Message from idpyoidc.server.construct import construct_provider_info from idpyoidc.server.exception import ToOld @@ -174,17 +175,16 @@ def payload( raise ValueError("Could not match expected 'acr'") if user_info: - try: + if isinstance(user_info, Message): user_info = user_info.to_dict() - except AttributeError: - pass # Make sure that there are no name clashes for key in ["iss", "sub", "aud", "exp", "acr", "nonce", "auth_time"]: - try: - del user_info[key] - except KeyError: - pass + if key in _args: + try: + del user_info[key] + except KeyError: + pass _args.update(user_info) @@ -242,6 +242,13 @@ def sign_encrypt( _context, client_info, "id_token", sign=sign, encrypt=encrypt ) + pack_args = {} + if user_info: + _aud = user_info.get("aud") + if _aud: + del user_info["aud"] + pack_args = {"aud": _aud} + _payload = self.payload( session_id=session_id, alg=alg_dict["sign_alg"], @@ -261,7 +268,7 @@ def sign_encrypt( **alg_dict, ) - return _jwt.pack(_payload, recv=client_id) + return _jwt.pack(_payload, recv=client_id, **pack_args) def __call__( self, @@ -275,9 +282,14 @@ def __call__( ) -> str: _context = self.upstream_get("context") + try: + del kwargs["client_id"] + except KeyError: + pass + user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id) - # Should I add session ID. This is about Single Logout. + # Should I add session ID ? This is about Single Logout. if include_session_id(_context, client_id, "back") or include_session_id( _context, client_id, "front" ): diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index b623d6a0..e4c94e74 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -45,7 +45,7 @@ class UserAuthnMethod(object): # override in subclass specifying suitable url endpoint to POST user input url_endpoint = "/verify" - FAILED_AUTHN = (None, True) + FAILED_AUTHN = (None, 0) def __init__(self, upstream_get=None, **kwargs): self.query_param = "upm_answer" @@ -314,7 +314,7 @@ class NoAuthn(UserAuthnMethod): # Just for testing allows anyone it without authentication - def __init__(self, user, upstream_get=None): + def __init__(self, user, upstream_get=None, **kwarg): UserAuthnMethod.__init__(self, upstream_get=upstream_get) self.user = user self.fail = None diff --git a/src/idpyoidc/server/user_info/__init__.py b/src/idpyoidc/server/user_info/__init__.py index f8206017..8aeb61d0 100755 --- a/src/idpyoidc/server/user_info/__init__.py +++ b/src/idpyoidc/server/user_info/__init__.py @@ -3,6 +3,9 @@ __author__ = "rolandh" +import logging + +logger = logging.getLogger(__name__) def dict_subset(a, b): for attr, values in a.items(): @@ -35,6 +38,7 @@ def __init__(self, db=None, db_file=""): self.db = db elif db_file: self.db = json.loads(open(db_file).read()) + logger.debug(f"Loaded user info file: {db_file}") else: self.db = {} diff --git a/src/idpyoidc/storage/abfile.py b/src/idpyoidc/storage/abfile.py index 6257fe21..33c116bc 100644 --- a/src/idpyoidc/storage/abfile.py +++ b/src/idpyoidc/storage/abfile.py @@ -89,6 +89,9 @@ def __getitem__(self, item): logger.info(f"File content change in {item}") fname = os.path.join(self.fdir, item) self.storage[item] = self._read_info(fname) + elif not self.storage[item]: + fname = os.path.join(self.fdir, item) + self.storage[item] = self._read_info(fname) _msg = f'Read from "{item}"' logger.debug(_msg) diff --git a/tests/afs/client_1.lock.lock b/tests/afs/client_1.lock.lock deleted file mode 100755 index e69de29b..00000000 diff --git a/tests/afs/client_2.lock.lock b/tests/afs/client_2.lock.lock deleted file mode 100755 index e69de29b..00000000 diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 96348945..3ef68090 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "XB2_T04TbhR_hmpm439FntWuuidEDy-H"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}]} \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/test_21_abfile.py b/tests/test_21_abfile.py index b29b779b..97b5ab72 100644 --- a/tests/test_21_abfile.py +++ b/tests/test_21_abfile.py @@ -137,3 +137,13 @@ def test_get(self): abf.clear() assert set(abf.keys()) == set() + + def test_read_multiple_times(self): + abf = AbstractFileSystem(fdir=full_path("afs"), value_conv="idpyoidc.util.JSON") + abf["client_1"] = CLIENT_1 + val = abf.get("client_1") + assert val == CLIENT_1 + val = abf.get("client_1") + assert val == CLIENT_1 + val = abf.get("client_1") + assert val == CLIENT_1 diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index 0229b26d..001d7ff7 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -243,7 +243,6 @@ def test_id_token_payload_0(self): "email_verified", "jti", "scope", - "client_id", "iss", "sid", } @@ -266,7 +265,6 @@ def test_id_token_payload_with_code(self): "email_verified", "jti", "scope", - "client_id", "c_hash", "iss", "iat", @@ -297,7 +295,6 @@ def test_id_token_payload_with_access_token(self): "email_verified", "jti", "scope", - "client_id", "iss", "iat", "nonce", @@ -330,7 +327,6 @@ def test_id_token_payload_with_code_and_access_token(self): "email_verified", "jti", "scope", - "client_id", "iss", "iat", "nonce", @@ -357,7 +353,6 @@ def test_id_token_payload_with_userinfo(self): "email_verified", "jti", "scope", - "client_id", "given_name", "aud", "exp", @@ -392,7 +387,6 @@ def test_id_token_payload_many_0(self): "email_verified", "jti", "scope", - "client_id", "sub", "auth_time", "given_name", diff --git a/tests/test_server_10_session_manager.py b/tests/test_server_10_session_manager.py index 8812a68f..4d0b6476 100644 --- a/tests/test_server_10_session_manager.py +++ b/tests/test_server_10_session_manager.py @@ -714,3 +714,19 @@ def test_find_latest_idtoken(self): idt = grant.last_issued_token_of_type("id_token") assert idt.session_id == id_token_3.session_id + + def test_dump_load(self): + session_id = self.session_manager.create_session( + authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id="diana", + client_id="client_1", + scopes=["openid", "phoe"], + ) + + _session_state = self.session_manager.dump() + _keys = set(self.session_manager.db.keys()) + self.session_manager.db = {} + + self.session_manager.load(_session_state) + assert set(self.session_manager.db.keys()) == _keys diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 627c4ebd..ece4d3b8 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -5,6 +5,7 @@ from urllib.parse import parse_qs from urllib.parse import urlparse +from cryptojwt.jws.jws import factory import pytest import yaml from cryptojwt import KeyJar @@ -773,6 +774,49 @@ def test_unwrap_identity(self): # # assert set(res.keys()) == {"authn_event", "identity", "user"} + def test_audience_id_token(self): + request = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + audience="https://aud.exmple.org" + ) + _context = self.endpoint.upstream_get("context") + _context.cdb["client_1"]["response_types"] = ["code", "token", "id_token"] + _pr_resp = self.endpoint.parse_request(request) + _resp = self.endpoint.process_request(_pr_resp) + _jws = factory(_resp["response_args"]["id_token"]) + _payload = _jws.jwt.payload() + assert 'aud' in _payload + + + # def test_audience(self): + # request = AuthorizationRequest( + # client_id="client_id", + # redirect_uri="https://rp.example.com/cb", + # response_type=["id_token"], + # state="state", + # nonce="nonce", + # scope="openid", + # audience="https://aud.exmple.org" + # ) + # redirect_uri = request["redirect_uri"] + # cinfo = { + # "client_id": "client_id", + # "redirect_uris": [("https://rp.example.com/cb", {})], + # "id_token_signed_response_alg": "RS256", + # } + # + # session_id = self._create_session(request) + # + # item = self.endpoint.upstream_get("context").authn_broker.db["anon"] + # item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) + # + # res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + # assert set(res.keys()) == {"session_id", "identity", "user"} def test_inputs(): elems = inputs(dict(foo="bar", home="stead")) diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index d4789959..17c41033 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -1,18 +1,18 @@ import json import os -import pytest from cryptojwt import JWT from cryptojwt import KeyJar from cryptojwt.jws.jws import factory from cryptojwt.key_jar import build_keyjar +import pytest from idpyoidc.context import OidcContext from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message import Message from idpyoidc.message import REQUIRED_LIST_OF_STRINGS from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.message.oauth2 import CCAccessTokenRequest from idpyoidc.message.oauth2 import JWTAccessToken from idpyoidc.message.oauth2 import ROPCAccessTokenRequest @@ -790,6 +790,25 @@ def test_refresh_token_request_other_client(self): assert isinstance(_resp, TokenErrorResponse) assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + def test_audience(self): + auth_req = AUTH_REQ.copy() + auth_req["audience"] = "https://foobar.example.org" + + session_id = self._create_session(auth_req) + grant = self.session_manager[session_id] + code = self._mint_code(grant, auth_req["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.context + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _jws = factory(_resp["response_args"]["access_token"]) + _payload = _jws.jwt.payload() + assert "aud" in _payload + assert _payload["aud"] == ['https://foobar.example.org'] + DEFAULT_TOKEN_HANDLER_ARGS = { "jwks_file": "private/token_jwks.json", diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index 836b9a81..fac7d13e 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -113,40 +113,26 @@ def full_path(local_file): "client_secret": 'hemligtkodord' "redirect_uris": - ['https://example.com/cb', ''] - "client_salt": "salted" - 'token_endpoint_auth_method': 'client_secret_post' - response_types_supported: + client_salt: salted + token_endpoint_auth_method: client_secret_post + response_types: - 'code' - 'code id_token' - 'id_token' - allowed_scopes: - - 'openid' - - 'profile' - - 'email' - - 'address' - - 'phone' - - 'offline_access' client2: client_secret: "spraket_sr.se" redirect_uris: - ['https://app1.example.net/foo', ''] - ['https://app2.example.net/bar', ''] - response_types_supported: + response_types: - code client3: client_secret: '2222222222222222222222222222222222222222' redirect_uris: - ['https://127.0.0.1:8090/authz_cb/bobcat', ''] post_logout_redirect_uri: ['https://openidconnect.net/', ''] - response_types_supported: + response_types: - code - allowed_scopes: - - 'openid' - - 'profile' - - 'email' - - 'address' - - 'phone' - - 'offline_access' """ @@ -880,7 +866,7 @@ def test_verify_response_type(self): assert self.endpoint.verify_response_type(request, client_info) is False - client_info["response_types_supported"] = [ + client_info["response_types"] = [ "code", "code id_token", "id_token", diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index 95255a70..52471fb4 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -198,23 +198,15 @@ def create_endpoint(self): "redirect_uris": [("{}cb".format(CLI1), None)], "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", - "response_types_supported": ["code", "code id_token", "id_token"], + "response_types": ["code", "code id_token", "id_token"], "post_logout_redirect_uri": [f"{CLI1}logout_cb", ""], - "allowed_scopes": [ - "openid", - "profile", - "email", - "address", - "phone", - "offline_access", - ], }, "client_2": { "client_secret": "hemligare", "redirect_uris": [("{}cb".format(CLI2), None)], "client_salt": "saltare", "token_endpoint_auth_method": "client_secret_post", - "response_types_supported": ["code", "code id_token", "id_token"], + "response_types": ["code", "code id_token", "id_token"], "post_logout_redirect_uri": [f"{CLI2}logout_cb", ""], "allowed_scopes": [ "openid", diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 323dd6d6..5e79e9fc 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -1,4 +1,5 @@ import io +import json import os import pytest @@ -251,7 +252,8 @@ def test_pushed_auth_urlencoded_process(self): # And now for the authorization request with the OP provided request_uri - _msg["request_uri"] = _resp["http_response"]["request_uri"] + _resp = json.loads(_resp["http_response"]) + _msg["request_uri"] = _resp["request_uri"] for parameter in ["code_challenge", "code_challenge_method"]: del _msg[parameter] diff --git a/tests/test_server_41_oauth2_def_conf.py b/tests/test_server_41_oauth2_def_conf.py new file mode 100644 index 00000000..d528851f --- /dev/null +++ b/tests/test_server_41_oauth2_def_conf.py @@ -0,0 +1,140 @@ +import os + +import pytest +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.defaults import DEFAULT_KEY_DEFS +from idpyoidc.client.oauth2 import Client +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.server import ASConfiguration +from idpyoidc.server import Server +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.util import rndstr + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) +AS_ENTITY_ID = "https://as.example.com" + + +class TestDefConf(object): + + @pytest.fixture(autouse=True) + def setup(self): + self.server = Server( + ASConfiguration( + conf={"authentication": + { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + } + }, + base_path=BASEDIR), + entity_id=AS_ENTITY_ID, + cwd=BASEDIR) + + conf = { + "redirect_uris": ["https://example.com/cli/authz_cb"], + "client_id": "client_1", + "client_secret": "abcdefghijklmnop", + "response_types_supported": ["code"], + "issuer": AS_ENTITY_ID + } + self.client = Client( + client_type="oauth2", + config=conf, + keyjar=build_keyjar(DEFAULT_KEY_DEFS) + ) + + self.context = self.server.context + self.context.cdb["client_1"] = conf + self.context.keyjar.import_jwks(self.client.keyjar.export_jwks(), "client_1") + + def test_init(self): + assert self.server + assert set(self.server.endpoint.keys()) == {'token', 'authorization', 'server_metadata'} + assert self.server.entity_id == self.server.issuer + assert self.server.entity_id == AS_ENTITY_ID + assert self.server.context.entity_id == AS_ENTITY_ID + + def do_query(self, service_type, endpoint_type, request_args, state): + _client_service = self.client.get_service(service_type) + req_info = _client_service.get_request_parameters(request_args=request_args) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + if areq: + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + areq.lax = True + _req = areq.serialize(_server_endpoint.request_format) + _pr_resp = _server_endpoint.parse_request(_req, **argv) + else: + _pr_resp = _server_endpoint.parse_request(areq) + + if is_error_message(_pr_resp): + return areq, _pr_resp + + _resp = _server_endpoint.process_request(_pr_resp) + if is_error_message(_resp): + return areq, _resp + + _response = _server_endpoint.do_response(**_resp) + + resp = _client_service.parse_response(_response["response"]) + _client_service.update_service_context(_resp["response_args"], key=state) + return areq, resp + + def process_setup(self, token=None, scope=None): + # ***** Discovery ********* + + _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") + + # ***** Authorization Request ********** + _nonce = rndstr(24) + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + if scope: + _scope = scope + else: + _scope = ["openid"] + + if token and list(token.keys())[0] == "refresh_token": + _scope = ["openid", "offline_access"] + + req_args["scope"] = _scope + + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) + + # ***** Token Request ********** + + req_args = { + "code": auth_response["code"], + "state": auth_response["state"], + "redirect_uri": areq["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.client.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) + + return resp, _state, _scope + + def test_flow(self): + """ + Test that token exchange requests work correctly + """ + + resp, _state, _scope = self.process_setup(token="access_token", scope=["foobar"]) + assert resp \ No newline at end of file diff --git a/tests/test_tandem_oauth2_code.py b/tests/test_tandem_oauth2_code.py index 091a0046..a3d94ed0 100644 --- a/tests/test_tandem_oauth2_code.py +++ b/tests/test_tandem_oauth2_code.py @@ -214,7 +214,7 @@ def process_setup(self, token=None, scope=None): _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") # ***** Authorization Request ********** - _nonce = (rndstr(24),) + _nonce = rndstr(24) _context = self.client.get_service_context() # Need a new state for a new authorization request _state = _context.cstate.create_state(iss=_context.get("issuer"))