From 015a89aeb5d3447753cb03470f65149ad13f6d01 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Wed, 18 Jan 2023 15:18:40 +0200 Subject: [PATCH 01/76] Fix SessionManager constructor call --- src/idpyoidc/server/session/manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 563e411a..db759cbc 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -94,8 +94,7 @@ def __init__( self.conf = conf or {"session_params": {"encrypter": default_crypt_config()}} session_params = self.conf.get("session_params") or {} - _crypt_config = get_crypt_config(session_params) - super(SessionManager, self).__init__(handler, _crypt_config) + super(SessionManager, self).__init__(handler, self.conf) self.node_type = session_params.get("node_type", ["user", "client", "grant"]) # Make sure node_type is a list and must contain at least one element. From e79b4594f488fc4c6d16d39dd998f4b94c92ab9b Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Thu, 19 Jan 2023 13:28:11 +0200 Subject: [PATCH 02/76] Fix bug related to the control of claims in tokens --- src/idpyoidc/server/session/claims.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/server/session/claims.py b/src/idpyoidc/server/session/claims.py index 029332cc..35b43ee4 100755 --- a/src/idpyoidc/server/session/claims.py +++ b/src/idpyoidc/server/session/claims.py @@ -68,11 +68,11 @@ def _client_claims( _context = self.server_get("endpoint_context") add_claims_by_scope = _context.cdb[client_id].get("add_claims", {}).get("by_scope", {}) if add_claims_by_scope: - _claims_by_scope = add_claims_by_scope.get(claims_release_point, False) - if not _claims_by_scope and secondary_identifier: + _claims_by_scope = add_claims_by_scope.get(claims_release_point) + if _claims_by_scope is None and secondary_identifier: _claims_by_scope = add_claims_by_scope.get(secondary_identifier, False) - if not _claims_by_scope: + if _claims_by_scope is None: _claims_by_scope = module.kwargs.get("add_claims_by_scope", {}) else: _claims_by_scope = module.kwargs.get("add_claims_by_scope", {}) From 22b3bb54ef1218354a99cbbbf9210325474c349a Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 20 Jan 2023 09:42:08 +0100 Subject: [PATCH 03/76] Make token.JWTToken use RFC9068 as model for payload. --- src/idpyoidc/message/oauth2/__init__.py | 19 +++++++++++++++++++ src/idpyoidc/server/token/jwt_token.py | 6 ++++++ tests/test_server_20e_jwt_token.py | 4 ++-- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index b8554f75..86494e68 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -404,6 +404,25 @@ class SecurityEventToken(Message): "toe": SINGLE_OPTIONAL_INT, } +class JWTAccessToken(Message): + c_param = { + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "aud": REQUIRED_LIST_OF_STRINGS, + "sub": SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT, + "jti": SINGLE_REQUIRED_STRING, + "auth_time": SINGLE_OPTIONAL_INT, + "acr": SINGLE_OPTIONAL_STRING, + "amr": OPTIONAL_LIST_OF_STRINGS, + 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, + 'groups': OPTIONAL_LIST_OF_STRINGS, + 'roles': OPTIONAL_LIST_OF_STRINGS, + 'entitlements': OPTIONAL_LIST_OF_STRINGS + } + + def factory(msgtype, **kwargs): """ diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index 9c8ab32a..e08eeff2 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -12,6 +12,8 @@ from . import is_expired from .exception import UnknownToken from .exception import WrongTokenClass +from ...message import Message +from ...message.oauth2 import JWTAccessToken class JWTToken(Token): @@ -81,6 +83,10 @@ def __call__( lifetime=lifetime, sign_alg=self.alg, ) + if isinstance(payload, Message): # don't mess with it. + pass + else: + payload = JWTAccessToken(**payload).to_dict() return signer.pack(payload) diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index 9bd86599..b363bac2 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -517,7 +517,7 @@ def test_mint_with_scope(self): grant, session_id, code, - scope=["openid"], + scope=["openid", 'foobar'], aud=["https://audience.example.com"], ) @@ -527,7 +527,7 @@ def test_mint_with_scope(self): assert _info["token_class"] == "access_token" # assert _info["eduperson_scoped_affiliation"] == ["staff@example.org"] assert set(_info["aud"]) == {"https://audience.example.com"} - assert _info["scope"] == ["openid"] + assert _info["scope"] == "openid foobar" def test_mint_with_extra(self): _auth_req = AuthorizationRequest( From 90b0405a61291a70aa7e22dcf2620c8147aaf4b4 Mon Sep 17 00:00:00 2001 From: Giuseppe De Marco Date: Fri, 20 Jan 2023 10:50:30 +0100 Subject: [PATCH 04/76] fix: default extended configuration, removed warning --- src/idpyoidc/server/configure.py | 2 +- src/idpyoidc/server/session/manager.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index cc34f5aa..61538603 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -461,7 +461,7 @@ def __init__( }, "httpc_params": {"verify": False, "timeout": 4}, "issuer": "https://{domain}:{port}", - "keys": { + "key_conf": { "private_path": "private/jwks.json", "key_defs": [ {"type": "RSA", "use": ["sig"]}, diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 563e411a..729ac3fb 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -495,6 +495,7 @@ def get_session_info_by_token( authorization_request: Optional[bool] = False, handler_key: Optional[str] = "", ) -> dict: + if handler_key: _token_info = self.token_handler.handler[handler_key].info(token_value) else: From 37e711fc836e1efd1d2e0053985f8f39d12d4cc9 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 23 Jan 2023 11:51:24 +0200 Subject: [PATCH 05/76] Various client authentication related fixes --- src/idpyoidc/server/client_authn.py | 59 ++++++++++--------- src/idpyoidc/server/endpoint.py | 6 +- src/idpyoidc/server/oidc/session.py | 2 +- src/idpyoidc/server/oidc/userinfo.py | 2 +- tests/test_server_17_client_authn.py | 15 +++-- tests/test_server_20d_client_authn.py | 15 +++-- ...st_server_23_oidc_registration_endpoint.py | 2 +- .../test_server_32_oidc_read_registration.py | 2 +- tests/test_server_60_dpop.py | 6 +- 9 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 1c62b556..cd3e2c2d 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -262,6 +262,7 @@ def _verify( request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] + get_client_id_from_token=None, **kwargs, ): _token = request.get("access_token") @@ -269,7 +270,7 @@ def _verify( raise ClientAuthenticationError("No access token") res = {"token": _token} - _client_id = request.get("client_id") + _client_id = get_client_id_from_token(endpoint_context, _token, request) if _client_id: res["client_id"] = _client_id return res @@ -483,6 +484,7 @@ def verify_client( auth_info = {} methods = endpoint_context.client_authn_method + client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: allowed_methods = list(methods.keys()) @@ -499,48 +501,47 @@ def verify_client( endpoint=endpoint, get_client_id_from_token=get_client_id_from_token, ) - break except (BearerTokenAuthenticationError, ClientAuthenticationError): raise except Exception as err: logger.info("Verifying auth using {} failed: {}".format(_method.tag, err)) + continue - if auth_info.get("method") == "none": - return auth_info + if auth_info.get("method") == "none" and auth_info.get("client_id") is None: + break - client_id = auth_info.get("client_id") - if client_id is None: - raise ClientAuthenticationError("Failed to verify client") + client_id = auth_info.get("client_id") + if client_id is None: + raise ClientAuthenticationError("Failed to verify client") - if also_known_as: - client_id = also_known_as[client_id] - auth_info["client_id"] = client_id + if also_known_as: + client_id = also_known_as[client_id] + auth_info["client_id"] = client_id - if client_id not in endpoint_context.cdb: - raise UnknownClient("Unknown Client ID") + if client_id not in endpoint_context.cdb: + raise UnknownClient("Unknown Client ID") - _cinfo = endpoint_context.cdb[client_id] + _cinfo = endpoint_context.cdb[client_id] - if not valid_client_info(_cinfo): - logger.warning("Client registration has timed out or " "client secret is expired.") - raise InvalidClient("Not valid client") + if not valid_client_info(_cinfo): + logger.warning("Client registration has timed out or " "client secret is expired.") + raise InvalidClient("Not valid client") - # Validate that the used method is allowed for this client/endpoint - client_allowed_methods = _cinfo.get( - f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method") - ) - if client_allowed_methods is not None and _method and _method.tag not in client_allowed_methods: - logger.info( - f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " - f"`{', '.join(client_allowed_methods)}`" - ) - raise UnAuthorizedClient( - f"Authentication method: {_method.tag} not allowed for client: {client_id} in " - f"endpoint: {endpoint.name}" + # Validate that the used method is allowed for this client/endpoint + client_allowed_methods = _cinfo.get( + f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method") ) + if client_allowed_methods is not None and auth_info["method"] not in client_allowed_methods: + logger.info( + f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " + f"`{', '.join(client_allowed_methods)}`" + ) + auth_info = {} + continue + break # store what authn method was used - if auth_info.get("method"): + if "method" in auth_info and client_id: _request_type = request.__class__.__name__ _used_authn_method = _cinfo.get("auth_method") if _used_authn_method: diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 24e7b3f6..2167285f 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -132,6 +132,9 @@ def set_client_authn_methods(self, **kwargs): kwargs[self.auth_method_attribute] = _methods elif _methods is not None: # [] or '' or something not None but regarded as nothing. self.client_authn_method = ["none"] # Ignore default value + elif self.default_capabilities: + self.client_authn_method = self.default_capabilities.get("client_authn_method") + self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) return kwargs def get_provider_info_attributes(self): @@ -249,7 +252,8 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No if authn_info == {} and self.client_authn_method and len(self.client_authn_method): LOGGER.debug("client_authn_method: %s", self.client_authn_method) raise UnAuthorizedClient("Authorization failed") - + if "client_id" not in authn_info and authn_info.get("method") != "none": + raise UnAuthorizedClient("Authorization failed") return authn_info def do_post_parse_request( diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 5768cf5b..716b8277 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -361,7 +361,7 @@ def parse_request(self, request, http_info=None, **kwargs): # Verify that the client is allowed to do this auth_info = self.client_authentication(request, http_info, **kwargs) - if not auth_info or auth_info["method"] == "none": + if not auth_info: pass elif isinstance(auth_info, ResponseMessage): return auth_info diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 6b5473d0..ae6e87b5 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -182,7 +182,7 @@ def parse_request(self, request, http_info=None, **kwargs): try: auth_info = self.client_authentication(request, http_info, **kwargs) except ClientAuthenticationError as e: - return self.error_cls(error="invalid_token", error_description=e.args[0]) + return self.error_cls(error="invalid_token", error_description="Invalid token") if isinstance(auth_info, ResponseMessage): return auth_info diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index 4575ecd8..d42a2325 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -337,7 +337,7 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} @@ -504,13 +504,12 @@ def test_verify_per_client_per_endpoint(self): ) assert res == {"method": "public", "client_id": client_id} - with pytest.raises(ClientAuthenticationError) as e: - verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), - ) - assert e.value.args[0] == "Failed to verify client" + res = verify_client( + self.endpoint_context, + request, + endpoint=self.server.server_get("endpoint", "endpoint_1"), + ) + assert res == {} request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index e81d26dd..55ab886c 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -292,7 +292,7 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} @@ -457,13 +457,12 @@ def test_verify_per_client_per_endpoint(self): ) assert res == {"method": "public", "client_id": client_id} - with pytest.raises(ClientAuthenticationError) as e: - verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "token"), - ) - assert e.value.args[0] == "Failed to verify client" + res = verify_client( + self.endpoint_context, + request, + endpoint=self.server.server_get("endpoint", "token"), + ) + assert res == {} request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 5b2ef4ae..64cb2a1b 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -127,7 +127,7 @@ def create_endpoint(self): "registration": { "path": "registration", "class": Registration, - "kwargs": {"client_auth_method": None}, + "kwargs": {"client_authn_method": ["none"]}, }, "authorization": { "path": "authorization", diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index 1f7670ad..2e803ba7 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -95,7 +95,7 @@ def create_endpoint(self): "registration": { "path": "registration", "class": Registration, - "kwargs": {"client_auth_method": None}, + "kwargs": {"client_authn_method": ["none"]}, }, "registration_api": { "path": "registration_api", diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index cd0301ef..7b74e172 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -164,7 +164,11 @@ def create_endpoint(self): "class": Authorization, "kwargs": {}, }, - "token": {"path": "{}/token", "class": Token, "kwargs": {}}, + "token": { + "path": "{}/token", + "class": Token, + "kwargs": {"client_authn_method": ["none"]}, + }, }, "client_authn": verify_client, "authentication": { From f9d2ab815136d246da49fede497405b232d804a2 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 23 Jan 2023 12:51:24 +0200 Subject: [PATCH 06/76] Introduce various token exchange enhancements --- src/idpyoidc/server/oauth2/token_helper.py | 64 +- src/idpyoidc/server/session/grant.py | 85 ++- src/idpyoidc/server/session/manager.py | 7 +- tests/test_server_36_oauth2_token_exchange.py | 680 +++++++++++++++++- tests/test_tandem_10_token_exchange.py | 10 +- 5 files changed, 815 insertions(+), 31 deletions(-) diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index 0475abfc..6a8e05aa 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -75,7 +75,7 @@ def _mint_token( token_args = meth(_context, client_id, token_args) if token_args: - _args = {"token_args": token_args} + _args = token_args else: _args = {} @@ -258,7 +258,6 @@ def process_request(self, req: Union[Message, dict], **kwargs): if ( issue_refresh and "refresh_token" in _supports_minting - and "refresh_token" in grant_types_supported ): try: refresh_token = self._mint_token( @@ -370,7 +369,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): token_type = "DPoP" token = _grant.get_token(token_value) - scope = _grant.find_scope(token.based_on) + scope = _grant.find_scope(token) if "scope" in req: scope = req["scope"] access_token = self._mint_token( @@ -543,6 +542,27 @@ def post_parse_request(self, request, client_id="", **kwargs): ) resp = self._enforce_policy(request, token, config) + if isinstance(resp, TokenErrorResponse): + return resp + + scopes = resp.get("scope", []) + scopes = _context.scopes_handler.filter_scopes(scopes, client_id=resp["client_id"]) + + if not scopes: + logger.error("All requested scopes have been filtered out.") + return self.error_cls( + error="invalid_scope", error_description="Invalid requested scopes" + ) + + _requested_token_type = resp.get( + "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" + ) + _token_class = self.token_types_mapping[_requested_token_type] + if _token_class == "refresh_token" and "offline_access" not in scopes: + return TokenErrorResponse( + error="invalid_request", + error_description="Exchanging this subject token to refresh token forbidden", + ) return resp @@ -572,7 +592,7 @@ def _enforce_policy(self, request, token, config): error_description="Unsupported requested token type", ) - request_info = dict(scope=request.get("scope", [])) + request_info = dict(scope=request.get("scope", token.scope)) try: check_unknown_scopes_policy(request_info, request["client_id"], _context) except UnAuthorizedClientScope: @@ -602,11 +622,11 @@ def _enforce_policy(self, request, token, config): logger.error(f"Error while executing the {fn} policy callable: {e}") return self.error_cls(error="server_error", error_description="Internal server error") - def token_exchange_response(self, token): + def token_exchange_response(self, token, issued_token_type): response_args = {} response_args["access_token"] = token.value response_args["scope"] = token.scope - response_args["issued_token_type"] = token.token_class + response_args["issued_token_type"] = issued_token_type if token.expires_at: response_args["expires_in"] = token.expires_at - utc_time_sans_frac() @@ -636,6 +656,7 @@ def process_request(self, request, **kwargs): error="invalid_request", error_description="Subject token invalid" ) + grant = _session_info["grant"] token = _mngr.find_token(_session_info["branch_id"], request["subject_token"]) _requested_token_type = request.get( "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" @@ -650,16 +671,19 @@ def process_request(self, request, **kwargs): if "dpop_signing_alg_values_supported" in _context.provider_info: if request.get("dpop_jkt"): _token_type = "DPoP" + scopes = request.get("scope", []) if request["client_id"] != _session_info["client_id"]: _token_usage_rules = _context.authz.usage_rules(request["client_id"]) sid = _mngr.create_exchange_session( exchange_request=request, + original_grant=grant, original_session_id=sid, user_id=_session_info["user_id"], client_id=request["client_id"], token_usage_rules=_token_usage_rules, + scopes=scopes, ) try: @@ -676,6 +700,10 @@ def process_request(self, request, **kwargs): else: resources = request.get("audience") + _token_args = None + if resources: + _token_args = {"resources": resources} + try: new_token = self._mint_token( token_class=_token_class, @@ -683,10 +711,11 @@ def process_request(self, request, **kwargs): session_id=sid, client_id=request["client_id"], based_on=token, - scope=request.get("scope"), - token_args={"resources": resources}, + scope=scopes, + token_args=_token_args, token_type=_token_type, ) + new_token.expires_at = token.expires_at except MintingNotAllowed: logger.error(f"Minting not allowed for {_token_class}") return self.error_cls( @@ -694,7 +723,7 @@ def process_request(self, request, **kwargs): error_description="Token Exchange not allowed with that token", ) - return self.token_exchange_response(token=new_token) + return self.token_exchange_response(new_token, _requested_token_type) def _validate_configuration(self, config): if "requested_token_types_supported" not in config: @@ -763,14 +792,13 @@ def validate_token_exchange_policy(request, context, subject_token, **kwargs): f"forbidden", ) - if "scope" in request: - scopes = list(set(request.get("scope")).intersection(kwargs.get("scope"))) - if scopes: - request["scope"] = scopes - else: - return TokenErrorResponse( - error="invalid_request", - error_description="No supported scope requested", - ) + scopes = request.get("scope", subject_token.scope) + scopes = list(set(scopes).intersection(subject_token.scope)) + if kwargs.get("scope"): + scopes = list(set(scopes).intersection(kwargs.get("scope"))) + if scopes: + request["scope"] = scopes + else: + request.pop("scope") return request diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index de54c4bc..761991fe 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -184,6 +184,7 @@ def payload_arguments( endpoint_context, item: SessionToken, claims_release_point: str, + scope: Optional[dict] = None, extra_payload: Optional[dict] = None, secondary_identifier: str = "", ) -> dict: @@ -211,6 +212,10 @@ def payload_arguments( payload["jti"] = uuid1().hex + if scope is None: + scope = self.scope + payload["scope"] = scope + if extra_payload: payload.update(extra_payload) @@ -359,6 +364,7 @@ def mint_token( endpoint_context, item=item, claims_release_point=claims_release_point, + scope=scope, extra_payload=handler_args, secondary_identifier=_secondary_identifier, ) @@ -474,7 +480,7 @@ def get_usage_rules(token_type, endpoint_context, grant, client_id): class ExchangeGrant(Grant): parameter = Grant.parameter.copy() - parameter.update({"users": []}) + parameter.update({"exchange_request": TokenExchangeRequest, "original_session_id": ""}) type = "exchange_grant" def __init__( @@ -483,6 +489,8 @@ def __init__( claims: Optional[dict] = None, resources: Optional[list] = None, authorization_details: Optional[dict] = None, + authorization_request: Optional[Message] = None, + authentication_event: Optional[AuthnEvent] = None, issued_token: Optional[list] = None, usage_rules: Optional[dict] = None, exchange_request: Optional[TokenExchangeRequest] = None, @@ -501,6 +509,8 @@ def __init__( claims=claims, resources=resources, authorization_details=authorization_details, + authorization_request=authorization_request, + authentication_event=authentication_event, issued_token=issued_token, usage_rules=usage_rules, issued_at=issued_at, @@ -517,3 +527,76 @@ def __init__( } self.exchange_request = exchange_request self.original_branch_id = original_branch_id + + def payload_arguments( + self, + session_id: str, + endpoint_context, + item: SessionToken, + claims_release_point: str, + scope: Optional[dict] = None, + extra_payload: Optional[dict] = None, + secondary_identifier: str = "", + ) -> dict: + """ + :param session_id: Session ID + :param endpoint_context: EndPoint Context + :param claims_release_point: One of "userinfo", "introspection", "id_token", "access_token" + :param scope: scope from the request + :param extra_payload: + :param secondary_identifier: Used if the claims returned are also based on rules for + another release_point + :param item: A SessionToken instance + :type item: SessionToken + :return: dictionary containing information to place in a token value + """ + payload = {} + for _in, _out in [("scope", "scope"), ("resources", "aud")]: + _val = getattr(item, _in) + if _val: + payload[_out] = _val + else: + _val = getattr(self, _in) + if _val: + payload[_out] = _val + + payload["jti"] = uuid1().hex + + if scope is None: + scope = self.scope + + payload = {"scope": scope, "aud": self.resources, "jti": uuid1().hex} + + if extra_payload: + payload.update(extra_payload) + + _jkt = self.extra.get("dpop_jkt") + if _jkt: + payload["cnf"] = {"jkt": _jkt} + + if self.exchange_request: + client_id = self.exchange_request.get("client_id") + if client_id: + payload.update({"client_id": client_id, "sub": self.sub}) + + if item.claims: + _claims_restriction = item.claims + else: + _claims_restriction = endpoint_context.claims_interface.get_claims( + session_id, + scopes=scope, + claims_release_point=claims_release_point, + secondary_identifier=secondary_identifier, + ) + + user_id, _, _ = endpoint_context.session_manager.decrypt_session_id(session_id) + user_info = endpoint_context.claims_interface.get_user_claims(user_id, _claims_restriction) + payload.update(user_info) + + # Should I add the acr value + if self.add_acr_value(claims_release_point): + payload["acr"] = self.authentication_event["authn_info"] + elif self.add_acr_value(secondary_identifier): + payload["acr"] = self.authentication_event["authn_info"] + + return payload \ No newline at end of file diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 563e411a..6c10a7a0 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -223,6 +223,7 @@ def create_grant( def create_exchange_grant( self, exchange_request: TokenExchangeRequest, + original_grant: Grant, original_session_id: str, user_id: str, client_id: Optional[str] = "", @@ -241,11 +242,13 @@ def create_exchange_grant( """ return self.add_exchange_grant( + authentication_event=original_grant.authentication_event, + authorization_request=original_grant.authorization_request, exchange_request=exchange_request, original_branch_id=original_session_id, path=self.make_path(user_id=user_id, client_id=client_id), + sub=original_grant.sub, token_usage_rules=token_usage_rules, - sub=self.sub_func[sub_type](user_id, salt=self.get_salt(), sector_identifier=""), scope=scopes ) @@ -286,6 +289,7 @@ def create_session( def create_exchange_session( self, exchange_request: TokenExchangeRequest, + original_grant: Grant, original_session_id: str, user_id: str, client_id: Optional[str] = "", @@ -309,6 +313,7 @@ def create_exchange_session( return self.create_exchange_grant( exchange_request=exchange_request, + original_grant=original_grant, original_session_id=original_session_id, user_id=user_id, client_id=client_id, diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 8b60d8b3..7c1e70c8 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -118,7 +118,10 @@ def create_endpoint(self): "introspection": { "path": "introspection", "class": "idpyoidc.server.oauth2.introspection.Introspection", - "kwargs": {}, + "kwargs": { + "client_authn_method": ["client_secret_post"], + "enable_claims_per_client": False, + }, }, }, "authentication": { @@ -182,6 +185,13 @@ def create_endpoint(self): "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + "urn:ietf:params:oauth:grant-type:token-exchange" + ], "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } @@ -346,7 +356,7 @@ def test_token_exchange_per_client(self, token): "policy": { "": { "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": {"scope": ["openid"]}, + "kwargs": {"scope": ["openid", "offline_access"]}, } }, } @@ -355,7 +365,6 @@ def test_token_exchange_per_client(self, token): if list(token.keys())[0] == "refresh_token": areq["scope"] = ["openid", "offline_access"] - session_id = self._create_session(areq) grant = self.endpoint_context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) @@ -386,6 +395,173 @@ def test_token_exchange_per_client(self, token): "issued_token_type", } + def test_token_exchange_scopes_per_client(self): + """ + Test that a client that requests offline_access in a Token Exchange request + only get it if the subject token has it in its scope set, if it is permitted + by the policy and if it is present in the clients allowed scopes. + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["openid", "profile", "offline_access"] + }, + } + }, + } + + self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["openid", "email", "profile", "offline_access"] + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + scope="openid profile offline_access" + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + # Note that offline_access is filtered because subject_token has no offline_access + # in its scope + assert set(_resp["response_args"]["scope"]) == set(["profile", "openid"]) + + def test_token_exchange_unsupported_scopes_per_client(self): + """ + Test that unsupported clients are handled appropriatelly + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["openid", "profile", "offline_access"] + }, + } + }, + "allowed_scopes": ["openid", "email", "profile", "offline_access"] + } + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + scope="email" + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert "scope" not in _resp + + def test_token_exchange_no_scopes_requested(self): + """ + Test that the correct scopes are returned when no scopes requested by the client + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["openid", "offline_access"] + }, + } + }, + "allowed_scopes": ["openid", "email", "profile", "offline_access"] + } + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token" + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["response_args"]["scope"] == ["openid"] + def test_additional_parameters(self): """ Test that a token exchange with additional parameters including @@ -438,7 +614,12 @@ def test_token_exchange_fails_if_disabled(self): Test that token exchange fails if it's not included in Token's grant_types_supported (that are set in its helper attribute). """ - del self.endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"] + self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + 'authorization_code', + 'implicit', + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + 'refresh_token' + ] areq = AUTH_REQ.copy() @@ -467,8 +648,8 @@ def test_token_exchange_fails_if_disabled(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Unsupported grant_type: urn:ietf:params:oauth:grant-type:token-exchange" + _resp["error_description"] + == "Unsupported grant_type: urn:ietf:params:oauth:grant-type:token-exchange" ) def test_wrong_resource(self): @@ -840,3 +1021,490 @@ def test_invalid_token(self): assert set(_resp.keys()) == {"error", "error_description"} assert _resp["error"] == "invalid_request" assert _resp["error_description"] == "Subject token invalid" + + def test_token_exchange_unsupported_scope_requested_1(self): + """ + Configuration: + - grant_types_supported: [authorization_code, refresh_token, ...:token-exchange] + - allowed_scopes: [profile, offline_access] + - requested_token_type: "...:access_token" + Scenario: + Client1 has an access_token1 (with offline_access, openid and profile scope). + Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["offline_access", "profile"] + }, + } + }, + } + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + areq["scope"].append("offline_access") + + self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["offline_access", "profile"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"offline_access", "profile"} + + token_exchange_req["scope"] = "profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile"} + + token_exchange_req["scope"] = "offline_access" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"offline_access"} + + token_exchange_req["scope"] = "offline_access profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"offline_access", "profile"} + + def test_token_exchange_unsupported_scope_requested_2(self): + """ + Configuration: + - grant_types_supported: [authorization_code, refresh_token, ...:token-exchange] + - allowed_scopes: [profile] + - requested_token_type: "...:access_token" + Scenario: + Client1 has an access_token1 (with openid and profile scope). + Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["profile"] + }, + } + }, + } + self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["openid", "profile"] + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + areq["scope"].append("offline_access") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile"} + + token_exchange_req["scope"] = "profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["response_args"]["scope"] == ["profile"] + + token_exchange_req["scope"] = "offline_access" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_scope" + assert ( + _resp["error_description"] + == "Invalid requested scopes" + ) + + token_exchange_req["scope"] = "offline_access profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["response_args"]["scope"] == ["profile"] + + def test_token_exchange_unsupported_scope_requested_3(self): + """ + Configuration: + - grant_types_supported: [authorization_code, ...:token-exchange] + - allowed_scopes: [offline_access, profile] + - requested_token_type: "...:access_token" + Scenario: + Client1 has an access_token1 (with openid and profile scope). + Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["offline_access", "profile"] + }, + } + }, + } + self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + 'authorization_code', + 'implicit', + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + 'urn:ietf:params:oauth:grant-type:token-exchange' + ] + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + areq["scope"].append("offline_access") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile", "offline_access"} + + token_exchange_req["scope"] = "profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["response_args"]["scope"] == ["profile"] + + token_exchange_req["scope"] = "offline_access" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["response_args"]["scope"] == ["offline_access"] + + _c_interface = self.introspection_endpoint.server_get("endpoint_context").claims_interface + grant.claims = { + "introspection": _c_interface.get_claims( + session_id, scopes=AUTH_REQ["scope"], claims_release_point="introspection" + ) + } + _req = self.introspection_endpoint.parse_request( + { + "token": _resp["response_args"]["access_token"], + "client_id": "client_1", + "client_secret": self.endpoint_context.cdb["client_1"]["client_secret"], + } + ) + _resp_intro = self.introspection_endpoint.process_request(_req) + assert _resp_intro["response_args"]["scope"] == "offline_access" + + token_exchange_req["scope"] = "offline_access profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile", "offline_access"} + + def test_token_exchange_unsupported_scope_requested_4(self): + """ + Configuration: + - grant_types_supported: [authorization_code, ...:token-exchange] + - allowed_scopes: [offline_access, profile] + - refresh_token removed from grant_types_supported + - requested_token_type: "...:access_token" + Scenario: + Client1 has an access_token1 (with openid and profile scope). + Then, client1 exchanges access_token1 for a new refresh token + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["offline_access", "profile"] + }, + } + }, + } + self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + 'authorization_code', + 'implicit', + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + 'urn:ietf:params:oauth:grant-type:token-exchange' + ] + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + areq["scope"].append("offline_access") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:refresh_token", + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile", "offline_access"} + + token_exchange_req["scope"] = "profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_request" + assert ( + _resp["error_description"] + == "Exchanging this subject token to refresh token forbidden" + ) + + token_exchange_req["scope"] = "offline_access" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"offline_access"} + + token_exchange_req["scope"] = "offline_access profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert set(_resp["response_args"]["scope"]) == {"profile", "offline_access"} + + token = grant.get_token(_resp["response_args"]["access_token"]) + assert token.token_class == "refresh_token" + + def test_token_exchange_unsupported_scope_requested_5(self): + """ + Configuration: + - grant_types_supported: [authorization_code, ...:token-exchange] + - allowed_scopes: [profile] + - requested_token_type: "...:access_token" + Scenario: + Client1 has an access_token1 (with openid and profile scope). + Then, client1 exchanges access_token1 for a new refresh token + """ + self.endpoint_context.cdb["client_1"]["token_exchange"] = { + "subject_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": { + "": { + "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "kwargs": { + "scope": ["profile"] + }, + } + }, + } + + areq = AUTH_REQ.copy() + areq["scope"].append("profile") + areq["scope"].append("offline_access") + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + _token_value = _resp["response_args"]["access_token"] + + token_exchange_req = TokenExchangeRequest( + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + subject_token=_token_value, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + requested_token_type="urn:ietf:params:oauth:token-type:refresh_token", + ) + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_request" + assert ( + _resp["error_description"] + == "Exchanging this subject token to refresh token forbidden" + ) + + token_exchange_req["scope"] = "profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_request" + assert ( + _resp["error_description"] + == "Exchanging this subject token to refresh token forbidden" + ) + + token_exchange_req["scope"] = "offline_access" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_scope" + assert ( + _resp["error_description"] + == "Invalid requested scopes" + ) + + token_exchange_req["scope"] = "offline_access profile" + + _req = self.endpoint.parse_request( + token_exchange_req.to_urlencoded(), + {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzE6aGVtbGlndA==")}}, + ) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_request" + assert ( + _resp["error_description"] + == "Exchanging this subject token to refresh token forbidden" + ) + diff --git a/tests/test_tandem_10_token_exchange.py b/tests/test_tandem_10_token_exchange.py index bf2c2649..b462b6c4 100644 --- a/tests/test_tandem_10_token_exchange.py +++ b/tests/test_tandem_10_token_exchange.py @@ -340,8 +340,8 @@ def test_token_exchange(self, token): "issued_token_type", } - assert _te_resp["issued_token_type"] == list(token.keys())[0] - assert _te_resp["scope"] == _scope + assert _te_resp["issued_token_type"] == token[list(token.keys())[0]] + assert set(_te_resp["scope"]) == set(_scope) @pytest.mark.parametrize( "token", @@ -367,7 +367,7 @@ def test_token_exchange_per_client(self, token): "": { "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": {"scope": ["openid"]}, + "kwargs": {"scope": ["openid", "offline_access"]}, } }, } @@ -395,8 +395,8 @@ def test_token_exchange_per_client(self, token): "issued_token_type", } - assert _te_resp["issued_token_type"] == list(token.keys())[0] - assert _te_resp["scope"] == _scope + assert _te_resp["issued_token_type"] == token[list(token.keys())[0]] + assert set(_te_resp["scope"]) == set(_scope) def test_additional_parameters(self): """ From 6e4a701c0aa44c216d6f703b53cbbe5f062de482 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 23 Jan 2023 13:05:04 +0200 Subject: [PATCH 07/76] Properly handle expired tokens on introspection --- src/idpyoidc/server/oauth2/introspection.py | 3 ++- tests/test_server_31_oauth2_introspection.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index f36d9503..11b29cca 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -6,6 +6,7 @@ from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.token.exception import UnknownToken from idpyoidc.server.token.exception import WrongTokenClass +from idpyoidc.server.exception import ToOld LOGGER = logging.getLogger(__name__) @@ -103,7 +104,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs _session_info = _context.session_manager.get_session_info_by_token( request_token, grant=True ) - except (UnknownToken, WrongTokenClass): + except (UnknownToken, WrongTokenClass, ToOld): return {"response_args": _resp} grant = _session_info["grant"] diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index f532db02..04748917 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -457,9 +457,14 @@ def test_jwt_unknown_key(self): _resp = self.introspection_endpoint.process_request(_req) assert _resp["response_args"]["active"] is False - def test_expired_access_token(self): + def test_expired_access_token(self, monkeypatch): access_token = self._get_access_token(AUTH_REQ) - access_token.expires_at = utc_time_sans_frac() - 1000 + lifetime = self.session_manager.token_handler.handler["access_token"].lifetime + + def mock(): + return utc_time_sans_frac() + lifetime + 1 + + monkeypatch.setattr("idpyoidc.server.token.utc_time_sans_frac", mock) _context = self.introspection_endpoint.server_get("endpoint_context") From d8d1c13b87b1baf9fd067c24f8ad828a3b734bc9 Mon Sep 17 00:00:00 2001 From: roland Date: Mon, 23 Jan 2023 16:18:38 +0100 Subject: [PATCH 08/76] The parameter 'lev' in serialization/deserialization functions/methods is not used - removed. --- src/idpyoidc/message/__init__.py | 62 ++++++++++++--------------- src/idpyoidc/message/oidc/__init__.py | 39 ++++++++--------- tests/test_04_message.py | 4 +- 3 files changed, 47 insertions(+), 58 deletions(-) diff --git a/src/idpyoidc/message/__init__.py b/src/idpyoidc/message/__init__.py index cbe1381c..234cb014 100644 --- a/src/idpyoidc/message/__init__.py +++ b/src/idpyoidc/message/__init__.py @@ -77,7 +77,7 @@ def set_defaults(self): for key, val in self.c_default.items(): self._dict.setdefault(key, val) - def to_urlencoded(self, lev=0): + def to_urlencoded(self): """ Creates a string using the application/x-www-form-urlencoded format @@ -114,13 +114,13 @@ def to_urlencoded(self, lev=0): params.append((key, val.encode("utf-8"))) elif isinstance(val, list): if _ser: - params.append((key, str(_ser(val, sformat="urlencoded", lev=lev)))) + params.append((key, str(_ser(val, sformat="urlencoded")))) else: for item in val: params.append((key, str(item).encode("utf-8"))) elif isinstance(val, Message): try: - _val = json.dumps(_ser(val, sformat="dict", lev=lev + 1)) + _val = json.dumps(_ser(val, sformat="dict")) params.append((key, _val)) except TypeError: params.append((key, val)) @@ -128,7 +128,7 @@ def to_urlencoded(self, lev=0): params.append((key, val)) else: try: - params.append((key, _ser(val, lev=lev))) + params.append((key, _ser(val))) except Exception: params.append((key, str(val))) @@ -143,18 +143,17 @@ def to_urlencoded(self, lev=0): _val.append((k, v)) return urlencode(_val) - def serialize(self, method="urlencoded", lev=0, **kwargs): + def serialize(self, method="urlencoded", **kwargs): """ Convert this instance to another representation. Which representation is given by the choice of serialization method. :param method: A serialization method. Presently 'urlencoded', 'json', 'jwt' and 'dict' is supported. - :param lev: :param kwargs: Extra key word arguments :return: THe content of this message serialized using a chosen method """ - return getattr(self, "to_%s" % method)(lev=lev, **kwargs) + return getattr(self, "to_%s" % method)(**kwargs) def deserialize(self, info, method="urlencoded", **kwargs): """ @@ -231,7 +230,7 @@ def from_urlencoded(self, urlencoded, **kwargs): return self - def to_dict(self, lev=0): + def to_dict(self): """ Return a dictionary representation of the class @@ -241,7 +240,6 @@ def to_dict(self, lev=0): _spec = self.c_param _res = {} - lev += 1 for key, val in self._dict.items(): try: _ser = _spec[str(key)][2] @@ -256,12 +254,12 @@ def to_dict(self, lev=0): _ser = None if _ser: - val = _ser(val, "dict", lev) + val = _ser(val, "dict") if isinstance(val, Message): - _res[key] = val.to_dict(lev + 1) + _res[key] = val.to_dict() elif isinstance(val, list) and isinstance(next(iter(val or []), None), Message): - _res[key] = [v.to_dict(lev) for v in val] + _res[key] = [v.to_dict() for v in val] else: _res[key] = val @@ -418,18 +416,14 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed, sformat="urlenc else: raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) - def to_json(self, lev=0, indent=None): + def to_json(self, indent=None): """ Serialize the content of this instance into a JSON string. - :param lev: :param indent: Number of spaces that should be used for indentation :return: """ - if lev: - return self.to_dict(lev + 1) - else: - return json.dumps(self.to_dict(1), indent=indent) + return json.dumps(self.to_dict(), indent=indent) def from_json(self, txt, **kwargs): """ @@ -443,18 +437,17 @@ def from_json(self, txt, **kwargs): _dict = json.loads(txt) return self.from_dict(_dict) - def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0): + def to_jwt(self, key=None, algorithm="", lifetime=0): """ Create a signed JWT representation of the class instance :param key: The signing key :param algorithm: The signature algorithm to use - :param lev: :param lifetime: The lifetime of the JWS :return: A signed JWT """ - _jws = JWS(self.to_json(lev), alg=algorithm) + _jws = JWS(self.to_json(), alg=algorithm) return _jws.sign_compact(key) def _gather_keys(self, keyjar, jwt, header, **kwargs): @@ -763,7 +756,7 @@ def update(self, item, **kwargs): else: raise ValueError("Can't update message using: '%s'" % (item,)) - def to_jwe(self, keys, enc, alg, lev=0): + def to_jwe(self, keys, enc, alg): """ Place the information in this instance in a JSON object. Make that JSON object the body of a JWT. Then encrypt that JWT using the @@ -772,12 +765,11 @@ def to_jwe(self, keys, enc, alg, lev=0): :param keys: list or KeyJar instance :param enc: Content Encryption Algorithm :param alg: Key Management Algorithm - :param lev: Used for JSON construction :return: An encrypted JWT. If encryption failed an exception will be raised. """ - _jwe = JWE(self.to_json(lev), alg=alg, enc=enc) + _jwe = JWE(self.to_json(), alg=alg, enc=enc) return _jwe.encrypt(keys) def from_jwe(self, msg, keys): @@ -861,7 +853,7 @@ def add_non_standard(msg1, msg2): # ============================================================================= -def list_serializer(vals, sformat="urlencoded", lev=0): +def list_serializer(vals, sformat="urlencoded"): if isinstance(vals, str) and sformat == "dict": return [vals] @@ -887,7 +879,7 @@ def list_deserializer(val, sformat="urlencoded"): return val -def sp_sep_list_serializer(vals, sformat="urlencoded", lev=0): +def sp_sep_list_serializer(vals, sformat="urlencoded"): if isinstance(vals, str): return vals else: @@ -903,7 +895,7 @@ def sp_sep_list_deserializer(val, sformat="urlencoded"): return val -def json_serializer(obj, sformat="urlencoded", lev=0): +def json_serializer(obj, sformat="urlencoded"): return json.dumps(obj) @@ -921,7 +913,7 @@ def msg_deser(val, sformat="urlencoded"): return Message().deserialize(val, sformat) -def msg_ser(inst, sformat, lev=0): +def msg_ser(inst, sformat): if sformat in ["urlencoded", "json"]: if isinstance(inst, dict): if sformat == "json": @@ -929,12 +921,12 @@ def msg_ser(inst, sformat, lev=0): else: res = urlencode([(k, v) for k, v in inst.items()]) elif isinstance(inst, Message): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) else: res = inst elif sformat == "dict": if isinstance(inst, Message): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) elif isinstance(inst, dict): res = inst elif isinstance(inst, str): # Iff ID Token @@ -947,7 +939,7 @@ def msg_ser(inst, sformat, lev=0): return res -def msg_list_deser(val, sformat="urlencoded", lev=0): +def msg_list_deser(val, sformat="urlencoded"): if isinstance(val, dict): return [Message(**val)] @@ -957,7 +949,7 @@ def msg_list_deser(val, sformat="urlencoded", lev=0): return _res -def msg_list_ser(val, sformat="urlencoded", lev=0): +def msg_list_ser(val, sformat="urlencoded"): _res = [] for v in val: _res.append(msg_ser(v, sformat)) @@ -1004,11 +996,11 @@ def msg_list_ser(val, sformat="urlencoded", lev=0): OPTIONAL_LIST_OF_MESSAGES = ([Message], False, msg_list_ser, msg_list_deser, False) -def any_ser(val, sformat="urlencoded", lev=0): +def any_ser(val, sformat="urlencoded"): if isinstance(val, (str, int, bool)): return val elif isinstance(val, Message): - return msg_ser(val, sformat, lev) + return msg_ser(val, sformat) elif isinstance(val, dict): return json.dumps(val) elif isinstance(val, list): @@ -1017,7 +1009,7 @@ def any_ser(val, sformat="urlencoded", lev=0): raise ValueError("Can't serialize this type of data") -def any_deser(val, sformat="urlencoded", lev=0): +def any_deser(val, sformat="urlencoded"): if isinstance(val, dict): return Message(**val) elif isinstance(val, list): diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 76a41588..67410091 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -80,11 +80,11 @@ def deserialize_from_one_of(val, msgtype, sformat): raise FormatError("Unexpected format") -def json_ser(val, sformat=None, lev=0): +def json_ser(val, sformat=None): return json.dumps(val) -def json_deser(val, sformat=None, lev=0): +def json_deser(val, sformat=None): return json.loads(val) @@ -108,14 +108,14 @@ def claims_deser(val, sformat="urlencoded"): return deserialize_from_one_of(val, Claims, sformat) -def msg_ser_json(inst, sformat="json", lev=0): +def msg_ser_json(inst, sformat="json"): # sformat = "json" always except when dict - if lev: - sformat = "dict" + # if lev: + # sformat = "dict" if sformat == "dict": if isinstance(inst, Message): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) elif isinstance(inst, dict): res = inst else: @@ -125,18 +125,18 @@ def msg_ser_json(inst, sformat="json", lev=0): if isinstance(inst, dict): res = json.dumps(inst) elif isinstance(inst, Message): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) else: res = inst return res -def msg_list_ser(insts, sformat, lev=0): - return [msg_ser(inst, sformat, lev) for inst in insts] +def msg_list_ser(insts, sformat): + return [msg_ser(inst, sformat) for inst in insts] -def claims_ser(val, sformat="urlencoded", lev=0): +def claims_ser(val, sformat="urlencoded"): # everything in c_extension if isinstance(val, str): item = val @@ -146,15 +146,12 @@ def claims_ser(val, sformat="urlencoded", lev=0): item = val if isinstance(item, Message): - return item.serialize(method=sformat, lev=lev + 1) + return item.serialize(method=sformat) if sformat == "urlencoded": res = urlencode(item) elif sformat == "json": - if lev: - res = item - else: - res = json.dumps(item) + res = json.dumps(item) elif sformat == "dict": if isinstance(item, dict): res = item @@ -771,9 +768,9 @@ def pack(self, alg="", **kwargs): else: self.pack_init() - def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0): + def to_jwt(self, key=None, algorithm="", lifetime=0): self.pack(alg=algorithm, lifetime=lifetime) - return Message.to_jwt(self, key=key, algorithm=algorithm, lev=lev) + return Message.to_jwt(self, key=key, algorithm=algorithm) def verify(self, **kwargs): super(IdToken, self).verify(**kwargs) @@ -1090,7 +1087,7 @@ def link_deser(val, sformat="urlencoded"): return _l_deser(val, sformat) -def link_ser(inst, sformat, lev=0): +def link_ser(inst, sformat): if sformat in ["urlencoded", "json"]: if isinstance(inst, dict): if sformat == "json": @@ -1098,12 +1095,12 @@ def link_ser(inst, sformat, lev=0): else: res = urlencode([(k, v) for k, v in inst.items()]) elif isinstance(inst, Link): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) else: res = inst elif sformat == "dict": if isinstance(inst, Link): - res = inst.serialize(sformat, lev) + res = inst.serialize(sformat) elif isinstance(inst, dict): res = inst elif isinstance(inst, str): # Iff ID Token @@ -1116,7 +1113,7 @@ def link_ser(inst, sformat, lev=0): return res -def link_list_ser(inst, sformat, lev=0): +def link_list_ser(inst, sformat): if isinstance(inst, list): return [link_ser(v, sformat) for v in inst] else: diff --git a/tests/test_04_message.py b/tests/test_04_message.py index 49b6bb0d..7fe18785 100644 --- a/tests/test_04_message.py +++ b/tests/test_04_message.py @@ -336,7 +336,7 @@ def test_to_jwe(keytype, alg, enc): def test_to_dict_with_message_obj(): content = Message(a={"a": {"foo": {"bar": [{"bat": []}]}}}) - _dict = content.to_dict(lev=0) + _dict = content.to_dict() content_fixture = {"a": {"a": {"foo": {"bar": [{"bat": []}]}}}} assert _dict == content_fixture @@ -344,7 +344,7 @@ def test_to_dict_with_message_obj(): def test_to_dict_with_raw_types(): msg = Message(c_default=[]) content_fixture = {"c_default": []} - _dict = msg.to_dict(lev=1) + _dict = msg.to_dict() assert _dict == content_fixture From 079d3d9cebc66e8abc919abac6d39882ff35aef4 Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 27 Jan 2023 10:38:49 +0100 Subject: [PATCH 09/76] Made choice of message class for JWTToken configurable. --- src/idpyoidc/server/token/jwt_token.py | 63 ++++++--- tests/test_server_24_oauth2_token_endpoint.py | 129 +++++++++++++++++- 2 files changed, 170 insertions(+), 22 deletions(-) diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index e08eeff2..e552115b 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -1,33 +1,36 @@ from typing import Callable from typing import Optional +from typing import Union from cryptojwt import JWT from cryptojwt.jws.exception import JWSException +from cryptojwt.utils import importer -from idpyoidc.encrypter import init_encrypter from idpyoidc.server.exception import ToOld - -from ..constant import DEFAULT_TOKEN_LIFETIME -from . import Token from . import is_expired +from . import Token from .exception import UnknownToken from .exception import WrongTokenClass +from ..constant import DEFAULT_TOKEN_LIFETIME from ...message import Message from ...message.oauth2 import JWTAccessToken class JWTToken(Token): + def __init__( - self, - token_class, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = DEFAULT_TOKEN_LIFETIME, - server_get: Callable = None, - token_type: str = "Bearer", - **kwargs + self, + token_class, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = DEFAULT_TOKEN_LIFETIME, + server_get: Callable = None, + token_type: str = "Bearer", + profile: Optional[Union[Message, str]] = JWTAccessToken, + with_jti: Optional[bool] = False, + **kwargs ): Token.__init__(self, token_class, **kwargs) self.token_type = token_type @@ -42,17 +45,27 @@ def __init__( self.def_aud = aud or [] self.alg = alg + if isinstance(profile, str): + self.profile = importer(profile) + else: + self.profile = profile + self.with_jti = with_jti + + if self.with_jti is False and profile == JWTAccessToken: + self.with_jti = True def load_custom_claims(self, payload: dict = None): # inherit me and do your things here return payload def __call__( - self, - session_id: Optional[str] = "", - token_class: Optional[str] = "", - usage_rules: Optional[dict] = None, - **payload + self, + session_id: Optional[str] = "", + token_class: Optional[str] = "", + usage_rules: Optional[dict] = None, + profile: Optional[Message] = None, + with_jti: Optional[bool] = None, + **payload ) -> str: """ Return a token. @@ -83,10 +96,18 @@ def __call__( lifetime=lifetime, sign_alg=self.alg, ) - if isinstance(payload, Message): # don't mess with it. + if isinstance(payload, Message): # don't mess with it. pass else: - payload = JWTAccessToken(**payload).to_dict() + if profile: + payload = profile(**payload).to_dict() + elif self.profile: + payload = self.profile(**payload).to_dict() + + if with_jti: + signer.with_jti = True + elif with_jti is None: + signer.with_jti = self.with_jti return signer.pack(payload) diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 25c479f4..68b63345 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -3,9 +3,20 @@ import pytest from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jws.jws import factory from cryptojwt.key_jar import build_keyjar +from idpyoidc.message import REQUIRED_LIST_OF_STRINGS +from idpyoidc.message import SINGLE_REQUIRED_INT + +from idpyoidc.message import SINGLE_REQUIRED_STRING + +from idpyoidc.message import Message + +from idpyoidc.context import OidcContext from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message.oauth2 import JWTAccessToken from idpyoidc.message.oidc import AccessTokenRequest from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import RefreshAccessTokenRequest @@ -18,7 +29,7 @@ from idpyoidc.server.exception import InvalidToken from idpyoidc.server.oauth2.authorization import Authorization from idpyoidc.server.oauth2.token import Token -from idpyoidc.server.session import MintingNotAllowed +from idpyoidc.server.token import handler from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.user_info import UserInfo from idpyoidc.time_util import utc_time_sans_frac @@ -162,6 +173,7 @@ def conf(): class TestEndpoint(object): + @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -777,3 +789,118 @@ def test_refresh_token_request_other_client(self): ) assert isinstance(_resp, TokenErrorResponse) assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + + +DEFAULT_TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"] + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} +TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + "profile": 'idpyoidc.message.oauth2.JWTAccessToken', + "with_jti": True + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} + +CONTEXT = OidcContext() +CONTEXT.cwd = BASEDIR +CONTEXT.issuer = "https://op.example.com" +CONTEXT.cdb = { + "client_1": {} +} +CONTEXT.keyjar = KeyJar() +CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") +CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") + +def server_get(what, *args): + if what == "endpoint_context": + if not args: + return CONTEXT + +def test_def_jwttoken(): + _handler = handler.factory(server_get=server_get, **DEFAULT_TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'resource_1', + 'client_id': 'client_1' + } + value = token_handler(session_id='session_id', **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + +def test_jwttoken(): + _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'resource_1', + 'client_id': 'client_1' + } + value = token_handler(session_id='session_id', **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + +class MyAccessToken(Message): + c_param = { + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "aud": REQUIRED_LIST_OF_STRINGS, + "sub": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT, + 'usage': SINGLE_REQUIRED_STRING + } + +def test_jwttoken_2(): + _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'Skiresort', + 'usage': 'skilift' + } + value = token_handler(session_id='session_id', profile=MyAccessToken, **token_payload) + + _jws = factory(value) + msg = MyAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True \ No newline at end of file From f095dc36cc79b8c802d78b5fc1ac32eed51d4f56 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Thu, 2 Feb 2023 17:02:16 +0200 Subject: [PATCH 10/76] Enforce aud restrictions Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/server/oauth2/introspection.py | 7 +++++++ tests/test_server_31_oauth2_introspection.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index 11b29cca..ce31961c 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -110,6 +110,13 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs grant = _session_info["grant"] _token = grant.get_token(request_token) + aud = _token.resources + if not aud: + aud = grant.resources + + if request["client_id"] not in aud: + return {"response_args": _resp} + _info = self._introspect(_token, _session_info["client_id"], _session_info["grant"]) if _info is None: return {"response_args": _resp} diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index 04748917..055caebe 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -494,6 +494,23 @@ def test_revoked_access_token(self): _resp = self.introspection_endpoint.process_request(_req) assert _resp["response_args"]["active"] is False + def test_wrong_aud(self): + auth_req = AUTH_REQ.copy() + auth_req["client_id"] = "client_2" + access_token = self._get_access_token(auth_req) + + _context = self.introspection_endpoint.server_get("endpoint_context") + + _req = self.introspection_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.introspection_endpoint.process_request(_req) + assert _resp["response_args"]["active"] is False + def test_introspect_id_token(self): session_id = self._create_session(AUTH_REQ) grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) From 3710cf21a47a9bfab11cca1984a6f8f034fd7dfb Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 11 Nov 2022 09:09:39 +0100 Subject: [PATCH 11/76] Spring(?)/Autumn cleaning. --- .../actor/client/oidc/registration.py | 4 +- src/idpyoidc/client/client_auth.py | 2 +- src/idpyoidc/client/entity.py | 152 ++++++++---------- src/idpyoidc/client/oauth2/access_token.py | 2 +- src/idpyoidc/client/oauth2/server_metadata.py | 2 +- src/idpyoidc/client/oidc/end_session.py | 2 +- .../client/oidc/provider_info_discovery.py | 2 +- src/idpyoidc/client/oidc/registration.py | 2 +- src/idpyoidc/client/oidc/userinfo.py | 8 +- src/idpyoidc/client/service.py | 44 ++--- src/idpyoidc/client/service_context.py | 8 +- src/idpyoidc/client/specification/__init__.py | 26 +-- src/idpyoidc/client/specification/oidc.py | 29 ++-- src/idpyoidc/message/oauth2/__init__.py | 2 +- src/idpyoidc/server/oauth2/token.py | 3 +- tests/static/jwks.json | 2 +- tests/test_client_01_service_context.py | 2 +- tests/test_client_02_entity.py | 2 +- tests/test_client_02b_entity_metadata.py | 14 +- 19 files changed, 152 insertions(+), 156 deletions(-) diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index 3e83cdc2..0de174cb 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -103,9 +103,9 @@ def _cmp(a, b): return a == b -def check(entity, attribute, expected): +def check(entity, claim, expected): try: - _usable = entity.get_metadata_attribute(attribute) + _usable = entity.get_metadata_claim(claim) except KeyError: pass else: diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 80ecaf25..240efd65 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -470,7 +470,7 @@ def _get_audience_and_algorithm(self, context, entity, **kwargs): if _alg : algorithm = _alg else: - algorithm = entity.get_metadata_value("token_endpoint_auth_signing_alg") + algorithm = entity.get_metadata_claim("token_endpoint_auth_signing_alg") if algorithm is None: _pi = context.provider_info try: diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 1783ee3c..771884d9 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -62,18 +62,18 @@ def set_jwks_uri_or_jwks(entity, service_context, config, jwks_uri, keyjar): keys_args = {k: v for k, v in config.get("key_conf").items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) entity.set_usage_value("jwks", True) - entity.set_metadata_value("jwks", _keyjar.export_jwks()) + entity.set_metadata_claim("jwks", _keyjar.export_jwks()) return elif keyjar: entity.set_usage_value("jwks", True) - entity.set_metadata_value("jwks", keyjar.export_jwks()) + entity.set_metadata_claim("jwks", keyjar.export_jwks()) return for attr in ["jwks_uri", "jwks"]: if entity.will_use(attr): _val = getattr(service_context, attr) if _val: - entity.set_metadata_value(attr, _val) + entity.set_metadata_claim(attr, _val) return @@ -85,7 +85,7 @@ def __init__( services: Optional[dict] = None, jwks_uri: Optional[str] = "", httpc_params: Optional[dict] = None, - client_type: Optional[str] = "" + client_type: Optional[str] = "oauth2" ): self.extra = {} if httpc_params: @@ -124,7 +124,7 @@ def __init__( self.setup_client_authn_methods(config) - jwks_uri = jwks_uri or self.get_metadata_value("jwks_uri") + jwks_uri = jwks_uri or self.get_metadata_claim("jwks_uri") set_jwks_uri_or_jwks(self, self._service_context, config, jwks_uri, _kj) # Deal with backward compatibility @@ -174,7 +174,7 @@ def collect_metadata(self): res = {} for service in self._service.values(): res.update(service.metadata) - res.update(self._service_context.specs.get_all()) + res.update(self._service_context.specs.get_metadata()) return res def collect_usage(self): @@ -184,133 +184,125 @@ def collect_usage(self): res.update(self._service_context.specs.usage) return res - def get_metadata_value(self, attribute, default=None): + def get_metadata_claim(self, claim, default=None): for service in self._service.values(): - if attribute in service.metadata_attributes: - return service.get_metadata(attribute, default) + if claim in service.metadata_claims: + return service.get_metadata(claim, default) - if attribute in self._service_context.specs.attributes: - return self._service_context.specs.get_metadata(attribute, default) + if claim in self._service_context.specs.attributes: + return self._service_context.specs.get_metadata_claim(claim, default) - raise KeyError(f"Unknown specs attribute: {attribute}") + raise KeyError(f"Unknown specs claim: {claim}") - def get_metadata_attributes(self): - attr = [] + def get_metadata_claims(self): + claims = [] for service in self._service.values(): - attr.extend(list(service.metadata_attributes.keys())) + claims.extend(list(service.metadata_claims.keys())) - attr.extend(list(self._service_context.specs.attributes.keys())) + claims.extend(list(self._service_context.specs.metadata.keys())) - return attr + return claims - def value_in_metadata_attribute(self, attribute, value): - for service in self._service.values(): - if attribute in service.metadata_attributes.keys(): - _val = service.get_metadata(attribute) - if isinstance(_val, list): - if value in _val: - return True - else: - if value == _val: - return True - - if attribute in self._service_context.specs.attributes.keys(): - _val = self._service_context.specs.get_metadata(attribute) - if isinstance(_val, list): - if value in _val: - return True - else: - if value == _val: - return True + def metadata_claim_contains_value(self, claim, value): + _val = self.get_metadata_claim(claim) + if isinstance(_val, list): + if value in _val: + return True + else: + if value == _val: + return True return False - def will_use(self, attribute): + def will_use(self, claim): for service in self._service.values(): - if attribute in service.usage_rules.keys(): - if service.usage.get(attribute): + if claim in service.usage_rules.keys(): + if service.usage.get(claim): return True - if attribute in self._service_context.specs.rules.keys(): - if self._service_context.specs.get_usage(attribute): + if claim in self._service_context.specs.rules.keys(): + if self._service_context.specs.get_usage(claim): return True return False - def set_metadata_value(self, attribute, value): + def set_metadata_claim(self, claim, value): """ Only OK to overwrite a value if the value is the default value """ for service in self._service.values(): - if attribute in service.metadata_attributes: - _def_val = service.metadata_attributes[attribute] + if claim in service.metadata_claims: + _def_val = service.metadata_claims[claim] if _def_val is None: - service.metadata[attribute] = value + service.metadata[claim] = value return True else: - if service.metadata.get(attribute, _def_val) == _def_val: - service.metadata[attribute] = value + if service.metadata.get(claim, _def_val) == _def_val: + service.metadata[claim] = value return True - if attribute in self._service_context.specs.attributes: - _def_val = self._service_context.specs.attributes[attribute] + if claim in self._service_context.specs.attributes: + _def_val = self._service_context.specs.attributes[claim] if _def_val is None: - self._service_context.specs.set_metadata(attribute, value) + self._service_context.specs.set_metadata_claim(claim, value) return True else: - if self._service_context.specs.get_metadata(attribute, _def_val): - self._service_context.specs.set_metadata(attribute, value) + if self._service_context.specs.get_metadata_claim(claim, _def_val): + self._service_context.specs.set_metadata_claim(claim, value) return True return True - logger.info(f"Unknown set specs attribute: {attribute}") + logger.info(f"Unknown set specs claim: {claim}") return False - def set_usage_value(self, attribute, value): + def set_usage_value(self, claim, value): """ Only OK to overwrite a value if the value is the default value """ for service in self._service.values(): - if attribute in service.usage_rules: - _def_val = service.usage_rules[attribute] + if claim in service.usage_rules: + _def_val = service.usage_rules[claim] if _def_val is None: - service.usage[attribute] = value + service.usage[claim] = value return True else: - if service.usage[attribute] == _def_val: - service.usage[attribute] = value + if service.usage[claim] == _def_val: + service.usage[claim] = value return True - if attribute in self._service_context.specs.rules: - _def_val = self._service_context.specs.rules[attribute] + if claim in self._service_context.specs.rules: + _def_val = self._service_context.specs.rules[claim] if _def_val is None: - self._service_context.specs.set_usage(attribute, value) + self._service_context.specs.set_usage(claim, value) return True else: - if self._service_context.specs.usage[attribute] == _def_val: - self._service_context.specs.set_usage(attribute, value) + if self._service_context.specs.usage[claim] == _def_val: + self._service_context.specs.set_usage(claim, value) return True - logger.info(f"Unknown set usage attribute: {attribute}") + logger.info(f"Unknown set usage claim: {claim}") return False - def get_usage_value(self, attribute, default=None): + def get_usage_value(self, claim, default=None): for service in self._service.values(): - if attribute in service.usage_rules: - if attribute in service.usage: - return service.usage[attribute] + if claim in service.usage_rules: + if claim in service.usage: + return service.usage[claim] else: return default - if attribute in self._service_context.specs.rules: - _val = self._service_context.specs.get_usage(attribute) + if claim in self._service_context.specs.rules: + _val = self._service_context.specs.get_usage(claim) if _val: return _val else: return default - logger.info(f"Unknown usage attribute: {attribute}") + logger.info(f"Unknown usage claim: {claim}") - def construct_uris(self, issuer, hash_seed, callback): + def construct_uris(self, + issuer: str, + hash_seed: bytes, + callback: Optional[dict]): _hash = hashlib.sha256() _hash.update(hash_seed) _hash.update(as_bytes(issuer)) @@ -322,15 +314,13 @@ def construct_uris(self, issuer, hash_seed, callback): for service in self._service.values(): service.construct_uris(_base_url, _hex) - if not self._service_context.specs.get_metadata("redirect_uris"): + if not self._service_context.specs.get_metadata_claim("redirect_uris"): self._service_context.specs.construct_redirect_uris(_base_url, _hex, callback) - self._service_context.specs.construct_uris(_base_url, _hex) - def backward_compatibility(self, config): _uris = config.get("redirect_uris") if _uris: - self.set_metadata_value("redirect_uris", _uris) + self.set_metadata_claim("redirect_uris", _uris) _dir = config.conf.get("requests_dir") if _dir: @@ -343,7 +333,7 @@ def backward_compatibility(self, config): _pref = config.get("client_preferences", {}) for key, val in _pref.items(): - if self.set_metadata_value(key, val) is False: + if self.set_metadata_claim(key, val) is False: if self.set_usage_value(key, val) is False: setattr(self, key, val) @@ -361,12 +351,12 @@ def config_args(self): res = {} for id, service in self._service.items(): res[id] = { - "metadata": service.metadata_attributes, + "metadata": service.metadata_claims, "usage": service.usage_rules } res[""] = { - "metadata": self._service_context.specs.attributes, - "usage": self._service_context.specs.rules + "metadata": self._service_context.specs.metadata, + "usage": self._service_context.specs.usage } return res diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 6979f4a1..666374a2 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -24,7 +24,7 @@ class AccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" - metadata_attributes = { + metadata_claims = { "token_endpoint_auth_method": "client_secret_basic", "token_endpoint_auth_signing_alg": "RS256" } diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index bf32700e..857c7075 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -22,7 +22,7 @@ class ServerMetadata(Service): service_name = "server_metadata" http_method = "GET" - metadata_attributes = {} + metadata_claims = {} def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 59bf4ee6..2e8d0778 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -20,7 +20,7 @@ class EndSession(Service): service_name = "end_session" response_body_type = "html" - metadata_attributes = { + metadata_claims = { "post_logout_redirect_uris": None, "frontchannel_logout_uri": None, "frontchannel_logout_session_required": None, diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 6caa4370..1b53296d 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -67,7 +67,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): error_msg = ResponseMessage service_name = "provider_info" - metadata_attributes = {} + metadata_claims = {} def __init__(self, client_get, conf=None): server_metadata.ServerMetadata.__init__(self, client_get, conf=conf) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index a8c384f6..ba9cc702 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -40,7 +40,7 @@ def add_client_behaviour_preference(self, request_args=None, **kwargs): try: request_args[prop] = _context.specs.behaviour[prop] except KeyError: - _val = _context.specs.get_metadata(prop) + _val = _context.specs.get_metadata_claim(prop) if _val: request_args[prop] = _val return request_args, {} diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index de18be9d..7a6f3f71 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -38,13 +38,7 @@ class UserInfo(Service): default_authn_method = "bearer_header" http_method = "GET" - metadata_attributes = { - "userinfo_signed_response_alg": "", - "userinfo_encrypted_response_alg": "", - "userinfo_encrypted_response_enc": "" - } - - metadata_attributes = { + metadata_claims = { "userinfo_signed_response_alg": None, "userinfo_encrypted_response_alg": None, "userinfo_encrypted_response_enc": None diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index d024793c..de143b68 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -13,8 +13,8 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.util import importer from .configure import Configuration from .exception import ResponseError @@ -63,7 +63,7 @@ class Service(ImpExp): init_args = ["client_get"] - metadata_attributes = {} + metadata_claims = {} usage_rules = {} usage_to_uri_map = {} callback_path = {} @@ -99,7 +99,7 @@ def __init__( md_conf = conf.get("metadata", {}) if md_conf: - for param, def_val in self.metadata_attributes.items(): + for param, def_val in self.metadata_claims.items(): if param in md_conf: self.metadata[param] = md_conf[param] elif def_val is not None: @@ -587,19 +587,22 @@ def parse_response( LOGGER.debug("response format: %s", sformat) - if sformat in ["jose", "jws", "jwe"]: - resp = self.post_parse_response(info, state=state) - - if not resp: - LOGGER.error("Missing or faulty response") - raise ResponseError("Missing or faulty response") - - return resp + resp = None + if sformat == "jose": + try: + self._do_jwt(info) + sformat = "dict" + except Exception: + _context = self.client_get("service_context") + resp = self.response_cls().from_jwe(info, keys=_context.keyjar) + elif sformat == "jwe": + _context = self.client_get("service_context") + resp = self.response_cls().from_jwe(info, keys=_context.keyjar) # If format is urlencoded 'info' may be a URL # in which case I have to get at the query/fragment part elif sformat == "urlencoded": info = self.get_urlinfo(info) - elif sformat == "jwt": + elif sformat in ["jwt", "jws"]: info = self._do_jwt(info) sformat = "dict" elif sformat == "json": @@ -608,7 +611,12 @@ def parse_response( LOGGER.debug("response_cls: %s", self.response_cls.__name__) - resp = self._do_response(info, sformat, **kwargs) + if resp is None: + if not info: + LOGGER.error("Missing or faulty response") + raise ResponseError("Missing or faulty response") + + resp = self._do_response(info, sformat, **kwargs) LOGGER.debug('Initial response parsing => "%s"', resp.to_dict()) @@ -635,7 +643,7 @@ def parse_response( def get_conf_attr(self, attr, default=None): """ - Get the value of a attribute in the configuration + Get the value of an attribute in the configuration :param attr: The attribute :param default: If the attribute doesn't appear in the configuration @@ -666,13 +674,13 @@ def construct_uris(self, base_url, hex): self.metadata[uri] = self.get_uri(base_url, self.callback_path[uri], hex) - def get_metadata(self, attribute, default=None): + def get_metadata_claim(self, claim, default=None): try: - return self.metadata[attribute] + return self.metadata[claim] except KeyError: return default - def set_metadata(self, key, value): + def set_metadata_claim(self, key, value): self.metadata[key] = value @@ -703,7 +711,7 @@ def init_services(service_definitions, client_get, metadata, usage): _srv = service_configuration["class"](**kwargs) for key, val in metadata.items(): - if key in _srv.metadata_attributes and key not in _srv.metadata: + if key in _srv.metadata_claims and key not in _srv.metadata: _srv.metadata[key] = val for key, val in usage.items(): diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index e52a7504..82aee548 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -116,10 +116,11 @@ def __init__(self, keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, state: Optional[StateInterface] = None, - client_type: Optional[str] = None, + client_type: Optional[str] = 'oauth2', **kwargs): config = get_configuration(config) self.config = config + if not client_type or client_type == "oidc": self.specs = OIDC_Specs() elif client_type == "oauth2": @@ -153,8 +154,7 @@ def __init__(self, for param in [ "client_secret", - "provider_info", - "behaviour" + "provider_info" ]: _val = config.conf.get(param, _def_value[param]) self.set(param, _val) @@ -270,4 +270,4 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.specs.get_metadata("client_id") + return self.specs.get_metadata_claim("client_id") diff --git a/src/idpyoidc/client/specification/__init__.py b/src/idpyoidc/client/specification/__init__.py index 5a413883..27e0eb63 100644 --- a/src/idpyoidc/client/specification/__init__.py +++ b/src/idpyoidc/client/specification/__init__.py @@ -82,10 +82,10 @@ def __init__(self, self.callback = {} self._local = {} - def get_all(self): + def get_metadata(self): return self.metadata - def get_metadata(self, key, default=None): + def get_metadata_claim(self, key, default=None): if key in self.metadata: return self.metadata[key] else: @@ -97,7 +97,7 @@ def get_usage(self, key, default=None): else: return default - def set_metadata(self, key, value): + def set_metadata_claim(self, key, value): self.metadata[key] = value def set_usage(self, key, value): @@ -105,7 +105,7 @@ def set_usage(self, key, value): def _callback_uris(self, base_url, hex): _red = {} - for type in self.get_metadata("response_types", ["code"]): + for type in self.get_metadata_claim("response_types", ["code"]): if "code" in type: _red['code'] = True elif type in ["id_token", "id_token token"]: @@ -120,12 +120,15 @@ def _callback_uris(self, base_url, hex): callback_uri[key] = _uri return callback_uri - def construct_redirect_uris(self, base_url, hex, callbacks): + def construct_redirect_uris(self, + base_url: str, + hex: str, + callbacks: Optional[dict] = None): if not callbacks: callbacks = self._callback_uris(base_url, hex) if callbacks: - self.set_metadata("redirect_uris", [v for k, v in callbacks.items()]) + self.set_metadata_claim("redirect_uris", [v for k, v in callbacks.items()]) self.callback = callbacks @@ -144,18 +147,18 @@ def load_conf(self, info): elif attr == "metadata": for k, v in val.items(): if k in self.attributes: - self.set_metadata(k, v) + self.set_metadata_claim(k, v) elif attr == "behaviour": self.behaviour = val elif attr in self.attributes: - self.set_metadata(attr, val) + self.set_metadata_claim(attr, val) elif attr in self.rules: self.set_usage(attr, val) - # defaults is nothing else is given + # defaults if nothing else is given for key, val in self.attributes.items(): if val and key not in self.metadata: - self.set_metadata(key, val) + self.set_metadata_claim(key, val) for key, val in self.rules.items(): if val and key not in self.usage: @@ -163,6 +166,7 @@ def load_conf(self, info): self.locals(info) self.verify_rules() + return self def bm_get(self, key, default=None): if key in self.behaviour: @@ -182,4 +186,4 @@ def set(self, key, val): self._local[key] = val def construct_uris(self, *args): - pass \ No newline at end of file + pass diff --git a/src/idpyoidc/client/specification/oidc.py b/src/idpyoidc/client/specification/oidc.py index 38bef174..dc84aaf9 100644 --- a/src/idpyoidc/client/specification/oidc.py +++ b/src/idpyoidc/client/specification/oidc.py @@ -1,12 +1,11 @@ import os from typing import Optional -from idpyoidc.client import specification as sp -from idpyoidc.client.service import Service +from idpyoidc.client import specification -class Specification(sp.Specification): - parameter = sp.Specification.parameter.copy() +class Specification(specification.Specification): + parameter = specification.Specification.parameter.copy() parameter.update({ "requests_dir": None }) @@ -64,17 +63,18 @@ def __init__(self, usage: Optional[dict] = None, behaviour: Optional[dict] = None, ): - sp.Specification.__init__(self, metadata=metadata, usage=usage, behaviour=behaviour) + specification.Specification.__init__(self, metadata=metadata, usage=usage, + behaviour=behaviour) - def construct_uris(self, base_url, hex): - if "request_uri" in self.usage: - if self.usage["request_uri"]: - _dir = self.get("requests_dir") - if _dir: - self.set_metadata("request_uris", Service.get_uri(base_url, _dir, hex)) - else: - self.set_metadata("request_uris", - Service.get_uri(base_url, self.callback_path["requests"], hex)) + # def construct_uris(self, base_url, hex): + # if "request_uri" in self.usage: + # if self.usage["request_uri"]: + # _dir = self.get("requests_dir") + # if _dir: + # self.set_metadata("request_uris", Service.get_uri(base_url, _dir, hex)) + # else: + # self.set_metadata("request_uris", + # Service.get_uri(base_url, self.callback_path["requests"], hex)) def verify_rules(self): if self.get_usage("request_parameter") and self.get_usage("request_uri"): @@ -91,4 +91,3 @@ def locals(self, info): os.makedirs(requests_dir) self.set("requests_dir", requests_dir) - diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 86494e68..ea2e5702 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -45,7 +45,7 @@ class ResponseMessage(Message): def verify(self, **kwargs): super(ResponseMessage, self).verify(**kwargs) if "error_description" in self: - # Verify that the characters used are within the allow ranges + # Verify that the characters used are within the allowed ranges # %x20-21 / %x23-5B / %x5D-7E if not all(x in error_chars for x in self["error_description"]): raise ValueError("Characters outside allowed set") diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index b23fa03c..e7c4fe85 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -44,7 +44,8 @@ def __init__(self, server_get, new_refresh_token=False, **kwargs): self.allow_refresh = False self.new_refresh_token = new_refresh_token self.configure_grant_types(kwargs.get("grant_types_helpers")) - self.grant_types_supported = kwargs.get("grant_types_supported", list(self.helper.keys())) + self.grant_types_supported = kwargs.get("grant_types_supported", + list(self.helper_by_grant_type.keys())) self.revoke_refresh_on_issue = kwargs.get("revoke_refresh_on_issue", False) def configure_grant_types(self, grant_types_helpers): diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index e2848728..251070c3 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -36,7 +36,7 @@ def test_create_callback_uris(self): base_url = "https://example.com/cli" hex = "0123456789" self.service_context.specs.construct_redirect_uris(base_url, hex, []) - _uris = self.service_context.specs.get_metadata("redirect_uris") + _uris = self.service_context.specs.get_metadata_claim("redirect_uris") assert len(_uris) == 1 assert _uris == [f"https://example.com/cli/authz_cb/{hex}"] diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index 405aa732..ba97b7da 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -39,7 +39,7 @@ def test_get_service_unsupported(self): assert _srv is None def test_get_client_id(self): - assert self.entity.get_metadata_value("client_id") == "Number5" + assert self.entity.get_metadata_claim("client_id") == "Number5" assert self.entity.client_get("client_id") == "Number5" def test_get_service_by_endpoint_name(self): diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index ffe31a0a..510e2fb8 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -77,7 +77,7 @@ def test_create_client(): - client = Entity(config=CLIENT_CONFIG) + client = Entity(config=CLIENT_CONFIG, client_type='oidc') _md = client.collect_metadata() assert set(_md.keys()) == {'application_type', 'backchannel_logout_uri', @@ -96,10 +96,10 @@ def test_create_client(): 'userinfo_signed_response_alg'} # What's in service configuration has higher priority then metadata. - assert client.get_metadata_value("contacts") == 'support@example.com' + assert client.get_metadata_claim("contacts") == 'support@example.com' # Two ways of looking at things - assert client.get_metadata_value("userinfo_signed_response_alg") == "ES256" - assert client.value_in_metadata_attribute("userinfo_signed_response_alg", "ES256") + assert client.get_metadata_claim("userinfo_signed_response_alg") == "ES256" + assert client.metadata_claim_contains_value("userinfo_signed_response_alg", "ES256") # How to act assert client.get_usage_value("request_uri") is True @@ -121,7 +121,7 @@ def test_create_client_key_conf(): client_config.update({"key_conf": KEY_CONF}) client = Entity(config=client_config) - _jwks = client.get_metadata_value("jwks") + _jwks = client.get_metadata_claim("jwks") assert _jwks @@ -130,11 +130,11 @@ def test_create_client_keyjar(): client_config = CLIENT_CONFIG.copy() client = Entity(config=client_config, keyjar=_keyjar) - _jwks = client.get_metadata_value("jwks") + _jwks = client.get_metadata_claim("jwks") assert _jwks def test_create_client_jwks_uri(): client_config = CLIENT_CONFIG.copy() client = Entity(config=client_config, jwks_uri="https://rp.example.com/jwks_uri.json") - assert client.get_metadata_value("jwks_uri") + assert client.get_metadata_claim("jwks_uri") From 7bf470c44d9b0e745f4c2fda8c14097b6e75da46 Mon Sep 17 00:00:00 2001 From: roland Date: Sat, 12 Nov 2022 08:28:03 +0100 Subject: [PATCH 12/76] Merged --- src/idpyoidc/__init__.py | 3 - src/idpyoidc/actor/__init__.py | 2 +- src/idpyoidc/actor/client/oidc/__init__.py | 7 +- .../actor/client/oidc/registration.py | 110 ++++++++-------- src/idpyoidc/actor/server/__init__.py | 2 +- src/idpyoidc/client/client_auth.py | 14 +- src/idpyoidc/client/configure.py | 2 +- src/idpyoidc/client/entity.py | 120 +++++++++--------- src/idpyoidc/client/http.py | 2 +- src/idpyoidc/client/oauth2/__init__.py | 2 +- src/idpyoidc/client/oauth2/authorization.py | 2 +- src/idpyoidc/client/oauth2/token_exchange.py | 1 - src/idpyoidc/client/oauth2/utils.py | 13 +- src/idpyoidc/client/oidc/access_token.py | 6 +- src/idpyoidc/client/oidc/authorization.py | 29 +++-- src/idpyoidc/client/oidc/end_session.py | 10 +- .../client/oidc/provider_info_discovery.py | 20 +-- .../client/oidc/refresh_access_token.py | 4 +- src/idpyoidc/client/oidc/registration.py | 6 +- src/idpyoidc/client/oidc/userinfo.py | 22 ++-- src/idpyoidc/client/oidc/utils.py | 4 +- src/idpyoidc/client/rp_handler.py | 9 +- src/idpyoidc/client/service.py | 44 +++---- src/idpyoidc/client/service_context.py | 29 ++--- src/idpyoidc/client/util.py | 14 +- .../__init__.py | 72 +++++------ .../oauth2.py | 11 +- .../{specification => work_condition}/oidc.py | 52 +++----- src/idpyoidc/impexp.py | 2 - src/idpyoidc/logging.py | 2 - src/idpyoidc/message/oidc/__init__.py | 2 +- src/idpyoidc/message/oidc/session.py | 16 +-- src/idpyoidc/server/client_authn.py | 1 - src/idpyoidc/server/client_configure.py | 6 - src/idpyoidc/server/oauth2/authorization.py | 6 +- src/idpyoidc/server/oauth2/server_metadata.py | 2 - src/idpyoidc/server/oauth2/token.py | 5 +- src/idpyoidc/server/oauth2/token_helper.py | 8 +- src/idpyoidc/server/oidc/authorization.py | 3 + src/idpyoidc/server/token/__init__.py | 1 - src/idpyoidc/server/util.py | 5 +- src/idpyoidc/time_util.py | 41 +++--- src/idpyoidc/util.py | 4 + tests/request123456.jwt | 2 +- tests/test_client_01_service_context.py | 14 +- tests/test_client_02b_entity_metadata.py | 8 +- tests/test_client_04_service.py | 4 +- tests/test_client_06_client_authn.py | 7 +- tests/test_client_12_client_auth.py | 2 +- .../test_client_14_service_context_impexp.py | 12 +- tests/test_client_21_oidc_service.py | 33 +++-- tests/test_client_24_oic_utils.py | 4 +- tests/test_client_28_rp_handler_oidc.py | 6 +- tests/test_client_30_rph_defaults.py | 2 +- tests/test_client_41_rp_handler_persistent.py | 2 +- tests/test_client_51_identity_assurance.py | 2 +- 56 files changed, 394 insertions(+), 420 deletions(-) rename src/idpyoidc/client/{specification => work_condition}/__init__.py (75%) rename src/idpyoidc/client/{specification => work_condition}/oauth2.py (74%) rename src/idpyoidc/client/{specification => work_condition}/oidc.py (64%) diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 5b03c94b..691c3c91 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,9 +1,6 @@ __author__ = "Roland Hedberg" __version__ = "1.4.0" -import os -from typing import Dict - VERIFIED_CLAIM_PREFIX = "__verified" diff --git a/src/idpyoidc/actor/__init__.py b/src/idpyoidc/actor/__init__.py index 4287ca86..792d6005 100644 --- a/src/idpyoidc/actor/__init__.py +++ b/src/idpyoidc/actor/__init__.py @@ -1 +1 @@ -# \ No newline at end of file +# diff --git a/src/idpyoidc/actor/client/oidc/__init__.py b/src/idpyoidc/actor/client/oidc/__init__.py index d439462b..ce1ab9d8 100644 --- a/src/idpyoidc/actor/client/oidc/__init__.py +++ b/src/idpyoidc/actor/client/oidc/__init__.py @@ -49,12 +49,11 @@ def do_client_notification(self, msg, http_info): _nreq = _notification_endpoint.parse_request( msg, http_info, get_client_id_from_token=self.get_client_id_from_token ) - _ninfo = _notification_endpoint.process_request(_nreq) + _notification_endpoint.process_request(_nreq) def construct_metadata(self): _reg_serv = self.client.client_get("service", "registration") - _info_c = _reg_serv.construct_request() - _reg_endp = self.server.server_get("endpoint", "discovery") - _info_e = _reg_endp.provider_info + _reg_serv.construct_request() + # _reg_endp = self.server.server_get("endpoint", "discovery") return {} diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index 0de174cb..abfe9be9 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -1,8 +1,4 @@ -import hashlib import logging -from typing import Optional - -from cryptojwt.utils import as_bytes from idpyoidc.client.service import Service from idpyoidc.message import oidc @@ -38,58 +34,58 @@ def response_types_to_grant_types(response_types): return list(_res) -def create_callbacks( - issuer: str, - hash_seed: str, - base_url: str, - code: Optional[bool] = False, - implicit: Optional[bool] = False, - form_post: Optional[bool] = False, - request_uris: Optional[bool] = False, - backchannel_logout_uri: Optional[bool] = False, - frontchannel_logout_uri: Optional[bool] = False, -): - """ - To mitigate some security issues the redirect_uris should be OP/AS - specific. This method creates a set of redirect_uris unique to the - OP/AS. - - :param frontchannel_logout_uri: Whether a front-channel logout uri should be constructed - :param backchannel_logout_uri: Whether a back-channel logout uri should be constructed - :param request_uri: Whether a request_uri should be constructed - :param issuer: Issuer ID - :return: A set of redirect_uris - """ - _hash = hashlib.sha256() - _hash.update(hash_seed) - _hash.update(as_bytes(issuer)) - _hex = _hash.hexdigest() - - res = {"__hex": _hex} - - if code: - res["code"] = f"{base_url}/authz_cb/{_hex}" - - if implicit: - res["implicit"] = f"{base_url}/authz_im_cb/{_hex}" - - if form_post: - res["form_post"] = f"{base_url}/authz_fp_cb/{_hex}" - - if request_uris: - res["request_uris"] = f"{base_url}/req_uri/{_hex}" - - if backchannel_logout_uri or frontchannel_logout_uri: - res["post_logout_redirect_uris"] = [f"{base_url}/session_logout/{_hex}"] - - if backchannel_logout_uri: - res["backchannel_logout_uri"] = f"{base_url}/bc_logout/{_hex}" - - if frontchannel_logout_uri: - res["frontchannel_logout_uri"] = f"{base_url}/fc_logout/{_hex}" - - logger.debug(f"Created callback URIs: {res}") - return res +# def create_callbacks( +# issuer: str, +# hash_seed: str, +# base_url: str, +# code: Optional[bool] = False, +# implicit: Optional[bool] = False, +# form_post: Optional[bool] = False, +# request_uris: Optional[bool] = False, +# backchannel_logout_uri: Optional[bool] = False, +# frontchannel_logout_uri: Optional[bool] = False, +# ): +# """ +# To mitigate some security issues the redirect_uris should be OP/AS +# specific. This method creates a set of redirect_uris unique to the +# OP/AS. +# +# :param frontchannel_logout_uri: Whether a front-channel logout uri should be constructed +# :param backchannel_logout_uri: Whether a back-channel logout uri should be constructed +# :param request_uri: Whether a request_uri should be constructed +# :param issuer: Issuer ID +# :return: A set of redirect_uris +# """ +# _hash = hashlib.sha256() +# _hash.update(hash_seed) +# _hash.update(as_bytes(issuer)) +# _hex = _hash.hexdigest() +# +# res = {"__hex": _hex} +# +# if code: +# res["code"] = f"{base_url}/authz_cb/{_hex}" +# +# if implicit: +# res["implicit"] = f"{base_url}/authz_im_cb/{_hex}" +# +# if form_post: +# res["form_post"] = f"{base_url}/authz_fp_cb/{_hex}" +# +# if request_uris: +# res["request_uris"] = f"{base_url}/req_uri/{_hex}" +# +# if backchannel_logout_uri or frontchannel_logout_uri: +# res["post_logout_redirect_uris"] = [f"{base_url}/session_logout/{_hex}"] +# +# if backchannel_logout_uri: +# res["backchannel_logout_uri"] = f"{base_url}/bc_logout/{_hex}" +# +# if frontchannel_logout_uri: +# res["frontchannel_logout_uri"] = f"{base_url}/fc_logout/{_hex}" +# +# logger.debug(f"Created callback URIs: {res}") +# return res def _cmp(a, b): @@ -166,7 +162,7 @@ def add_client_behaviour_preference(self, request_args=None, **kwargs): continue try: - request_args[prop] = _context.specs.behaviour[prop] + request_args[prop] = _context.work_condition.behaviour[prop] except KeyError: try: request_args[prop] = _context.client_preferences[prop] diff --git a/src/idpyoidc/actor/server/__init__.py b/src/idpyoidc/actor/server/__init__.py index 4287ca86..792d6005 100644 --- a/src/idpyoidc/actor/server/__init__.py +++ b/src/idpyoidc/actor/server/__init__.py @@ -1 +1 @@ -# \ No newline at end of file +# diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 240efd65..de7bb357 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -1,7 +1,6 @@ """Implementation of a number of client authentication methods.""" import base64 import logging -from urllib.parse import quote_plus from cryptojwt.exception import MissingKey from cryptojwt.exception import UnsupportedAlgorithm @@ -11,14 +10,13 @@ from idpyoidc.defaults import DEF_SIGN_ALG from idpyoidc.defaults import JWT_BEARER -from idpyoidc.message.oauth2 import SINGLE_OPTIONAL_STRING from idpyoidc.message.oauth2 import AccessTokenRequest +from idpyoidc.message.oauth2 import SINGLE_OPTIONAL_STRING from idpyoidc.message.oidc import AuthnToken from idpyoidc.time_util import utc_time_sans_frac from idpyoidc.util import rndstr - -from ..message import VREQUIRED from .util import sanitize +from ..message import VREQUIRED # from idpyoidc.oidc.backchannel_authentication import ClientNotificationAuthn @@ -135,8 +133,8 @@ def _with_or_without_client_id(request, service): :param service: A :py:class:`idpyoidc.client.service.Service` instance """ if ( - isinstance(request, AccessTokenRequest) - and request["grant_type"] == "authorization_code" + isinstance(request, AccessTokenRequest) + and request["grant_type"] == "authorization_code" ): if "client_id" not in request: try: @@ -467,7 +465,7 @@ def _get_audience_and_algorithm(self, context, entity, **kwargs): # we're talking to. if "authn_endpoint" in kwargs and kwargs["authn_endpoint"] in ["token_endpoint"]: _alg = context.registration_response.get("token_endpoint_auth_signing_alg") - if _alg : + if _alg: algorithm = _alg else: algorithm = entity.get_metadata_claim("token_endpoint_auth_signing_alg") @@ -480,7 +478,7 @@ def _get_audience_and_algorithm(self, context, entity, **kwargs): else: for alg in algs: # pick the first one I support and have keys for if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar( - alg, context + alg, context ): algorithm = alg break diff --git a/src/idpyoidc/client/configure.py b/src/idpyoidc/client/configure.py index cfdd30f0..966e7f6c 100755 --- a/src/idpyoidc/client/configure.py +++ b/src/idpyoidc/client/configure.py @@ -7,7 +7,6 @@ from idpyoidc.configure import Base from idpyoidc.logging import configure_logging -from idpyoidc.message.oidc import RegistrationResponse from .util import lower_or_upper try: @@ -26,6 +25,7 @@ class RPHConfiguration(Base): + def __init__( self, conf: Dict, diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 771884d9..cf1415fc 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -47,25 +47,25 @@ def response_types_to_grant_types(response_types): def set_jwks_uri_or_jwks(entity, service_context, config, jwks_uri, keyjar): # lots of different ways to configure the RP's keys if jwks_uri: - entity.set_usage_value("jwks_uri", True) - entity.set_metadata_value("jwks_uri", jwks_uri) + entity.set_support("jwks_uri", True) + entity.set_metadata_claim("jwks_uri", jwks_uri) else: if config.get("jwks_uri"): - entity.set_usage_value("jwks_uri", True) - entity.set_usage_value("jwks", False) + entity.set_support("jwks_uri", True) + entity.set_support("jwks", False) elif config.get("jwks"): - entity.set_usage_value("jwks", True) - entity.set_usage_value("jwks_uri", False) + entity.set_support("jwks", True) + entity.set_support("jwks_uri", False) else: - entity.set_usage_value("jwks_uri", False) + entity.set_support("jwks_uri", False) if config.get("key_conf"): keys_args = {k: v for k, v in config.get("key_conf").items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) - entity.set_usage_value("jwks", True) + entity.set_support("jwks", True) entity.set_metadata_claim("jwks", _keyjar.export_jwks()) return elif keyjar: - entity.set_usage_value("jwks", True) + entity.set_support("jwks", True) entity.set_metadata_claim("jwks", keyjar.export_jwks()) return @@ -120,7 +120,7 @@ def __init__( self._service = init_services(service_definitions=_srvs, client_get=self.client_get, metadata=config.conf.get("metadata", {}), - usage=config.conf.get("usage", {})) + support=config.conf.get("support", {})) self.setup_client_authn_methods(config) @@ -174,23 +174,23 @@ def collect_metadata(self): res = {} for service in self._service.values(): res.update(service.metadata) - res.update(self._service_context.specs.get_metadata()) + res.update(self._service_context.work_condition.get_metadata()) return res - def collect_usage(self): + def collect_support(self): res = {} for service in self._service.values(): - res.update(service.usage) - res.update(self._service_context.specs.usage) + res.update(service.support) + res.update(self._service_context.work_condition.support) return res def get_metadata_claim(self, claim, default=None): for service in self._service.values(): if claim in service.metadata_claims: - return service.get_metadata(claim, default) + return service.get_metadata_claim(claim, default) - if claim in self._service_context.specs.attributes: - return self._service_context.specs.get_metadata_claim(claim, default) + if claim in self._service_context.work_condition.metadata_claims: + return self._service_context.work_condition.get_metadata_claim(claim, default) raise KeyError(f"Unknown specs claim: {claim}") @@ -199,7 +199,14 @@ def get_metadata_claims(self): for service in self._service.values(): claims.extend(list(service.metadata_claims.keys())) - claims.extend(list(self._service_context.specs.metadata.keys())) + claims.extend(list(self._service_context.work_condition.metadata_claims.keys())) + + return claims + + def get_claim_sources(self): + claims = {'': list(self._service_context.work_condition.metadata_claims.keys())} + for service in self._service.values(): + claims[service.endpoint_name] = list(service.metadata_claims.keys()) return claims @@ -214,14 +221,14 @@ def metadata_claim_contains_value(self, claim, value): return False - def will_use(self, claim): + def will_use(self, facet): for service in self._service.values(): - if claim in service.usage_rules.keys(): - if service.usage.get(claim): + if facet in service.can_support.keys(): + if service.support.get(facet): return True - if claim in self._service_context.specs.rules.keys(): - if self._service_context.specs.get_usage(claim): + if facet in self._service_context.work_condition.can_support.keys(): + if self._service_context.work_condition.get_support(facet): return True return False @@ -240,64 +247,61 @@ def set_metadata_claim(self, claim, value): service.metadata[claim] = value return True - if claim in self._service_context.specs.attributes: - _def_val = self._service_context.specs.attributes[claim] + if claim in self._service_context.work_condition.metadata_claims: + _def_val = self._service_context.work_condition.metadata_claims[claim] if _def_val is None: - self._service_context.specs.set_metadata_claim(claim, value) + self._service_context.work_condition.set_metadata_claim(claim, value) return True else: - if self._service_context.specs.get_metadata_claim(claim, _def_val): - self._service_context.specs.set_metadata_claim(claim, value) + if self._service_context.work_condition.get_metadata_claim(claim, _def_val): + self._service_context.work_condition.set_metadata_claim(claim, value) return True return True logger.info(f"Unknown set specs claim: {claim}") return False - def set_usage_value(self, claim, value): + def set_support(self, claim, value): """ Only OK to overwrite a value if the value is the default value """ for service in self._service.values(): - if claim in service.usage_rules: - _def_val = service.usage_rules[claim] + if claim in service.can_support: + _def_val = service.can_support[claim] if _def_val is None: - service.usage[claim] = value + service.support[claim] = value return True else: - if service.usage[claim] == _def_val: - service.usage[claim] = value + if service.support[claim] == _def_val: + service.support[claim] = value return True - if claim in self._service_context.specs.rules: - _def_val = self._service_context.specs.rules[claim] + if claim in self._service_context.work_condition.can_support: + _def_val = self._service_context.work_condition.can_support[claim] if _def_val is None: - self._service_context.specs.set_usage(claim, value) + self._service_context.work_condition.set_support(claim, value) return True else: - if self._service_context.specs.usage[claim] == _def_val: - self._service_context.specs.set_usage(claim, value) + if self._service_context.work_condition.can_support[claim] == _def_val: + self._service_context.work_condition.set_support(claim, value) return True - logger.info(f"Unknown set usage claim: {claim}") + logger.info(f"Unknown set support claim: {claim}") return False - def get_usage_value(self, claim, default=None): + def get_support(self, claim, default=None): for service in self._service.values(): - if claim in service.usage_rules: - if claim in service.usage: - return service.usage[claim] - else: - return default + if claim in service.can_support.keys(): + return service.support.get(claim, default) - if claim in self._service_context.specs.rules: - _val = self._service_context.specs.get_usage(claim) + if claim in self._service_context.work_condition.can_support: + _val = self._service_context.work_condition.get_support(claim) if _val: return _val else: return default - logger.info(f"Unknown usage claim: {claim}") + logger.info(f"Unknown support claim: {claim}") def construct_uris(self, issuer: str, @@ -314,8 +318,8 @@ def construct_uris(self, for service in self._service.values(): service.construct_uris(_base_url, _hex) - if not self._service_context.specs.get_metadata_claim("redirect_uris"): - self._service_context.specs.construct_redirect_uris(_base_url, _hex, callback) + if not self._service_context.work_condition.get_metadata_claim("redirect_uris"): + self._service_context.work_condition.construct_redirect_uris(_base_url, _hex, callback) def backward_compatibility(self, config): _uris = config.get("redirect_uris") @@ -326,7 +330,7 @@ def backward_compatibility(self, config): if _dir: authz_serv = self.get_service('authorization') if authz_serv: # If this isn't true that's weird. Tests perhaps ? - self.set_usage_value("request_uri", True) + self.set_support("request_uri", True) if not os.path.isdir(_dir): os.makedirs(_dir) authz_serv.callback_path["request_uris"] = _dir @@ -334,12 +338,12 @@ def backward_compatibility(self, config): _pref = config.get("client_preferences", {}) for key, val in _pref.items(): if self.set_metadata_claim(key, val) is False: - if self.set_usage_value(key, val) is False: + if self.set_support(key, val) is False: setattr(self, key, val) for key, val in config.conf.items(): if key not in ["port", "domain", "httpc_params", "metadata", "client_preferences", - "usage", "services", "add_ons"]: + "support", "services", "add_ons"]: self.extra[key] = val auth_request_args = config.conf.get("request_args", {}) @@ -352,11 +356,11 @@ def config_args(self): for id, service in self._service.items(): res[id] = { "metadata": service.metadata_claims, - "usage": service.usage_rules + "support": service.can_support } res[""] = { - "metadata": self._service_context.specs.metadata, - "usage": self._service_context.specs.usage + "metadata": self._service_context.work_condition.metadata_claims, + "support": self._service_context.work_condition.can_support } return res @@ -364,5 +368,5 @@ def get_callback_uris(self): res = [] for service in self._service.values(): res.extend(service.callback_uris) - res.extend(self._service_context.specs.callback_uris) + res.extend(self._service_context.work_condition.callback_uris) return res diff --git a/src/idpyoidc/client/http.py b/src/idpyoidc/client/http.py index d7825787..e846ceed 100644 --- a/src/idpyoidc/client/http.py +++ b/src/idpyoidc/client/http.py @@ -67,7 +67,7 @@ def set_cookie(self, response): except CookieError as err: logger.error(err) raise NonFatalException(response, "{}".format(err)) - except (AttributeError, KeyError) as err: + except (AttributeError, KeyError): pass def __call__(self, url, method="GET", **kwargs): diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index d49fefea..2a2a7125 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -103,7 +103,7 @@ def do_request( try: _state = kwargs["state"] - except: + except Exception: _state = "" return self.service_request( _srv, response_body_type=response_body_type, state=_state, **_info diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 59665964..c55cdbb2 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -46,7 +46,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: # ar_args["redirect_uri"] = self.client_get("service_context").redirect_uris[0] - ar_args["redirect_uri"] = self.client_get("entity").get_metadata_value( + ar_args["redirect_uri"] = self.client_get("entity").get_metadata_claim( "redirect_uris")[0] except (KeyError, AttributeError): raise MissingParameter("redirect_uri") diff --git a/src/idpyoidc/client/oauth2/token_exchange.py b/src/idpyoidc/client/oauth2/token_exchange.py index f583ac7a..ed6390af 100644 --- a/src/idpyoidc/client/oauth2/token_exchange.py +++ b/src/idpyoidc/client/oauth2/token_exchange.py @@ -26,7 +26,6 @@ class TokenExchange(Service): request_body_type = "urlencoded" response_body_type = "json" - def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index b9a693c0..c32deb89 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -1,5 +1,4 @@ import logging -from typing import List from typing import Optional from typing import Union @@ -36,9 +35,9 @@ def pick_redirect_uri( if "redirect_uri" in request_args: return request_args["redirect_uri"] - if context.specs.callback: + if context.work_condition.callback: if not response_type: - _conf_resp_types = context.specs.behaviour.get("response_types", []) + _conf_resp_types = context.work_condition.behaviour.get("response_types", []) response_type = request_args.get("response_type") if not response_type and _conf_resp_types: response_type = _conf_resp_types[0] @@ -46,18 +45,18 @@ def pick_redirect_uri( _response_mode = request_args.get("response_mode") if _response_mode == "form_post" or response_type == ["form_post"]: - redirect_uri = context.specs.callback["form_post"] + redirect_uri = context.work_condition.callback["form_post"] elif response_type == "code" or response_type == ["code"]: - redirect_uri = context.specs.callback["code"] + redirect_uri = context.work_condition.callback["code"] else: - redirect_uri = context.specs.callback["implicit"] + redirect_uri = context.work_condition.callback["implicit"] logger.debug( f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, " f"redirect_uri={redirect_uri}" ) else: - redirect_uris = entity.get_metadata_value("redirect_uris", []) + redirect_uris = entity.get_metadata_claim("redirect_uris", []) if redirect_uris: redirect_uri = redirect_uris[0] else: diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 122cbf22..1292d1a4 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -55,7 +55,7 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.specs.behaviour.get("verify_args") + _verify_args = _context.work_condition.behaviour.get("verify_args") if _verify_args: if _verify_args: kwargs.update(_verify_args) @@ -83,8 +83,8 @@ def update_service_context(self, resp, key="", **kwargs): _state_interface.store_item(resp, "token_response", key) def get_authn_method(self): - _specs = self.client_get("service_context").specs + _work_condition = self.client_get("service_context").work_condition try: - return _specs.behaviour["token_endpoint_auth_method"] + return _work_condition.behaviour["token_endpoint_auth_method"] except KeyError: return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 9cea8b7f..2792702a 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -27,7 +27,16 @@ class Authorization(authorization.Authorization): response_cls = oidc.AuthorizationResponse error_msg = oidc.ResponseMessage - usage_rules = { + can_support = { + "request_uris": None + } + + callback_path = { + "request_uris": "request", + } + + support_to_uri = { + "request_uris": "request_uris", } def __init__(self, client_get, conf=None): @@ -93,7 +102,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): try: _response_types = [request_args["response_type"]] except KeyError: - _response_types = _context.specs.behaviour.get("response_types") + _response_types = _context.work_condition.behaviour.get("response_types") if _response_types: request_args["response_type"] = _response_types[0] else: @@ -101,7 +110,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): # For OIDC 'openid' is required in scope if "scope" not in request_args: - _scope = self.client_get("entity").get_usage_value("scope") + _scope = self.client_get("entity").get_support("scope") if _scope: request_args["scope"] = _scope else: @@ -133,9 +142,9 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): post_args["request_param"] = "request" del kwargs["request_method"] else: - if _entity.get_usage_value("request_uri"): + if _entity.get_support("request_uri"): post_args["request_param"] = "request_uri" - elif _entity.get_usage_value("request_parameter"): + elif _entity.get_support("request_parameter"): post_args["request_param"] = "request" return request_args, post_args @@ -153,7 +162,7 @@ def get_request_object_signing_alg(self, **kwargs): if not alg: _context = self.client_get("service_context") try: - alg = _context.specs.behaviour["request_object_signing_alg"] + alg = _context.work_condition.behaviour["request_object_signing_alg"] except KeyError: # Use default alg = "RS256" return alg @@ -255,15 +264,15 @@ def oidc_post_construct(self, req, **kwargs): if _request_param: del kwargs["request_param"] else: - if _context.specs.get_usage("request_uri"): + if _context.work_condition.get_support("request_uri"): _request_param = "request_uri" - elif _context.specs.get_usage("request_parameter"): + elif _context.work_condition.get_support("request_parameter"): _request_param = "request" _req = None # just a flag if _request_param == "request_uri": kwargs["base_path"] = _context.get("base_url") + "/" + "requests" - kwargs["local_dir"] = _context.specs.get("requests_dir", "./requests") + kwargs["local_dir"] = _context.work_condition.get("requests_dir", "./requests") _req = self.construct_request_parameter(req, _request_param, **kwargs) req["request_uri"] = self.store_request_on_file(_req, **kwargs) elif _request_param == "request": @@ -313,7 +322,7 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.specs.behaviour.get("verify_args") + _verify_args = _context.work_condition.behaviour.get("verify_args") if _verify_args: kwargs.update(_verify_args) diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 2e8d0778..8df64fab 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -28,7 +28,7 @@ class EndSession(Service): "backchannel_logout_session_required": None } - usage_rules = { + can_support = { "frontchannel_logout": None, "backchannel_logout": None, "post_logout_redirects": None @@ -40,18 +40,12 @@ class EndSession(Service): "post_logout_redirect_uris": "session_logout" } - usage_to_uri_map = { + support_to_uri = { "frontchannel_logout": "frontchannel_logout_uri", "backchannel_logout": "backchannel_logout_uri", "post_logout_redirect": "post_logout_redirect_uris" } - callback_uris = [ - "frontchannel_logout_uri", - "backchannel_logout_uri", - "post_logout_redirect_uris" - ] - def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct = [ diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 1b53296d..08a16c7c 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -74,7 +74,7 @@ def __init__(self, client_get, conf=None): def update_service_context(self, resp, **kwargs): _context = self.client_get("service_context") - self._update_service_context(resp) + self._update_service_context(resp) # set endpoints and import keys self.match_preferences(resp, _context.issuer) if "pre_load_keys" in self.conf and self.conf["pre_load_keys"]: _jwks = _context.keyjar.export_jwks_as_json(issuer=resp["issuer"]) @@ -101,14 +101,14 @@ def match_preferences(self, pcr=None, issuer=None): regreq = oidc.RegistrationRequest - _behaviour = _context.specs.behaviour + _behaviour = _context.work_condition.behaviour for _pref, _prov in PREFERENCE2PROVIDER.items(): if _pref in ["scope"]: - vals = _entity.get_usage_value(_pref) + vals = _entity.get_support(_pref) else: try: - vals = _entity.get_metadata_value(_pref) + vals = _entity.get_metadata_claim(_pref) except KeyError: continue @@ -134,10 +134,14 @@ def match_preferences(self, pcr=None, issuer=None): vtyp = regreq.c_param[_pref] except KeyError: # Allow non standard claims - if isinstance(vals, list): + if isinstance(vals, list) and isinstance(_pvals, list): _behaviour[_pref] = [v for v in vals if v in _pvals] - elif vals in _pvals: - _behaviour[_pref] = vals + elif isinstance(_pvals, list): + if vals in _pvals: + _behaviour[_pref] = vals + elif type(vals) == type(_pvals): + if vals == _pvals: + _behaviour[_pref] = vals else: if isinstance(vtyp[0], list): _behaviour[_pref] = [] @@ -170,5 +174,5 @@ def match_preferences(self, pcr=None, issuer=None): if key not in PREFERENCE2PROVIDER: _behaviour[key] = val - _context.specs.behaviour = _behaviour + _context.work_condition.behaviour = _behaviour logger.debug("service_context behaviour: {}".format(_behaviour)) diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index ddc38837..30d9d3c0 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -8,8 +8,8 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): error_msg = oidc.ResponseMessage def get_authn_method(self): - _specs = self.client_get("service_context").specs + _work_condition = self.client_get("service_context").work_condition try: - return _specs.behaviour["token_endpoint_auth_method"] + return _work_condition.behaviour["token_endpoint_auth_method"] except KeyError: return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index ba9cc702..0df84cc8 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -38,9 +38,9 @@ def add_client_behaviour_preference(self, request_args=None, **kwargs): continue try: - request_args[prop] = _context.specs.behaviour[prop] + request_args[prop] = _context.work_condition.behaviour[prop] except KeyError: - _val = _context.specs.get_metadata_claim(prop) + _val = _context.work_condition.get_metadata_claim(prop) if _val: request_args[prop] = _val return request_args, {} @@ -67,7 +67,7 @@ def update_service_context(self, resp, key="", **kwargs): _context.registration_response = resp _client_id = resp.get("client_id") if _client_id: - _context.specs.set_metadata("client_id", _client_id) + _context.work_condition.set_metadata("client_id", _client_id) if _client_id not in _context.keyjar: _context.keyjar.import_jwks( _context.keyjar.export_jwks(True, ""), issuer_id=_client_id diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 7a6f3f71..4ace120c 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -95,7 +95,7 @@ def post_parse_response(self, response, **kwargs): ) except MissingSigningKey as err: logger.warning( - "Error encountered while unpacking aggregated " "claims".format(err) + f"Error encountered while unpacking aggregated claims: {err}" ) else: claims = [ @@ -104,16 +104,16 @@ def post_parse_response(self, response, **kwargs): for key in claims: response[key] = aggregated_claims[key] - elif "endpoint" in spec: - _info = { - "headers": self.get_authn_header( - {}, - self.default_authn_method, - authn_endpoint=self.endpoint_name, - key=kwargs["state"], - ), - "url": spec["endpoint"], - } + # elif "endpoint" in spec: + # _info = { + # "headers": self.get_authn_header( + # {}, + # self.default_authn_method, + # authn_endpoint=self.endpoint_name, + # key=kwargs["state"], + # ), + # "url": spec["endpoint"], + # } # Extension point for meth in self.post_parse_process: diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index 887ac00d..7fd075ff 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -20,7 +20,7 @@ def request_object_encryption(msg, service_context, **kwargs): encalg = kwargs["request_object_encryption_alg"] except KeyError: try: - encalg = service_context.specs.behaviour["request_object_encryption_alg"] + encalg = service_context.work_condition.behaviour["request_object_encryption_alg"] except KeyError: return msg @@ -31,7 +31,7 @@ def request_object_encryption(msg, service_context, **kwargs): encenc = kwargs["request_object_encryption_enc"] except KeyError: try: - encenc = service_context.specs.behaviour["request_object_encryption_enc"] + encenc = service_context.work_condition.behaviour["request_object_encryption_enc"] except KeyError: raise MissingRequiredAttribute("No request_object_encryption_enc specified") diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 4cc4df5c..8ef27bd0 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -383,9 +383,9 @@ def client_setup( def _get_response_type(self, context, req_args: Optional[dict] = None): if req_args: - return req_args.get("response_type", context.specs.behaviour["response_types"][0]) + return req_args.get("response_type", context.work_condition.behaviour["response_types"][0]) else: - return context.specs.behaviour["response_types"][0] + return context.work_condition.behaviour["response_types"][0] def init_authorization( self, @@ -422,7 +422,7 @@ def init_authorization( "redirect_uri": pick_redirect_uri( _context, _entity, request_args=req_args, response_type=_response_type ), - "scope": _context.specs.behaviour["scope"], + "scope": _context.work_condition.behaviour["scope"], "response_type": _response_type, "nonce": _nonce, } @@ -631,6 +631,7 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): ["access_token"], ["auth_response", "token_response", "refresh_token_response"], ) + access_token = _arg["access_token"] request_args = {"access_token": access_token} @@ -842,7 +843,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): _sid_support = _context.get("provider_info")[ "frontchannel_logout_session_required" ] - except: + except Exception: _sid_support = False if _sid_support and _id_token: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index de143b68..6cbb0330 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -7,7 +7,6 @@ from urllib.parse import urlparse from cryptojwt.jwt import JWT -from cryptojwt.utils import qualified_name from idpyoidc.client.exception import Unsupported from idpyoidc.impexp import ImpExp @@ -64,10 +63,10 @@ class Service(ImpExp): init_args = ["client_get"] metadata_claims = {} - usage_rules = {} - usage_to_uri_map = {} + can_support = {} + support = {} + support_to_uri = {} callback_path = {} - callback_uris = [] def __init__( self, @@ -80,7 +79,7 @@ def __init__( self.client_get = client_get self.default_request_args = {} self.metadata = {} - self.usage = {} + self.support = {} self.callback_uri = {} if conf: @@ -105,13 +104,13 @@ def __init__( elif def_val is not None: self.metadata[param] = def_val - usage_conf = conf.get("usage", {}) - if usage_conf: - for param, def_val in self.usage_rules.items(): - if param in usage_conf: - self.usage[param] = usage_conf[param] + support_conf = conf.get("support", {}) + if support_conf: + for facet, def_val in self.can_support.items(): + if facet in support_conf: + self.support[facet] = support_conf[facet] elif def_val is not None: - self.usage[param] = def_val + self.support[facet] = def_val _default_request_args = conf.get("request_args", {}) if _default_request_args: @@ -163,7 +162,7 @@ def gather_request_args(self, **kwargs): if not val: val = self.default_request_args.get(prop) if not val: - val = _context.specs.behaviour.get(prop) + val = _context.work_condition.behaviour.get(prop) if not val: val = md.get(prop) if val: @@ -539,6 +538,7 @@ def _do_response(self, info, sformat, **kwargs): try: resp = self.response_cls().deserialize(info, sformat, iss=_context.issuer, **kwargs) except Exception as err: + LOGGER.error("Error while deserializing: %s (1 pass)", err) resp = None if sformat == "json": # Could be JWS or JWE but wrongly tagged @@ -667,12 +667,11 @@ def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" def construct_uris(self, base_url, hex): - for usage in self.usage_rules.keys(): - if usage in self.usage: - uri = self.usage_to_uri_map.get(usage) + for activity, _support in self.support.items(): + if _support: + uri = self.support_to_uri.get(activity) if uri and uri not in self.metadata: - self.metadata[uri] = self.get_uri(base_url, self.callback_path[uri], - hex) + self.metadata[uri] = self.get_uri(base_url, self.callback_path[uri], hex) def get_metadata_claim(self, claim, default=None): try: @@ -684,12 +683,13 @@ def set_metadata_claim(self, key, value): self.metadata[key] = value -def init_services(service_definitions, client_get, metadata, usage): +def init_services(service_definitions, client_get, metadata, support): """ Initiates a set of services :param service_definitions: A dictionary containing service definitions :param client_get: A function that returns different things from the base entity. + :param support: What facets of the service that can be used :return: A dictionary, with service name as key and the service instance as value. """ @@ -703,20 +703,18 @@ def init_services(service_definitions, client_get, metadata, usage): kwargs.update({"client_get": client_get}) if isinstance(service_configuration["class"], str): - _value_cls = service_configuration["class"] _cls = importer(service_configuration["class"]) _srv = _cls(**kwargs) else: - _value_cls = qualified_name(service_configuration["class"]) _srv = service_configuration["class"](**kwargs) for key, val in metadata.items(): if key in _srv.metadata_claims and key not in _srv.metadata: _srv.metadata[key] = val - for key, val in usage.items(): - if key in _srv.usage_rules and key not in _srv.usage: - _srv.usage[key] = val + for key, val in support.items(): + if key in _srv.can_support and key not in _srv.support: + _srv.support[key] = val service[_srv.service_name] = _srv diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 82aee548..fdeaa12e 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -6,22 +6,22 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import RSAKey from cryptojwt.jwk.rsa import import_private_rsa_key_from_file +from cryptojwt.jwk.rsa import RSAKey from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes from idpyoidc.client.configure import Configuration -from idpyoidc.client.specification.oauth2 import Specification as OAUTH2_Specs -from idpyoidc.client.specification.oidc import Specification as OIDC_Specs +from idpyoidc.client.work_condition.oauth2 import WorkCondition as OAUTH2_Specs +from idpyoidc.client.work_condition.oidc import WorkCondition as OIDC_Specs from idpyoidc.context import OidcContext from idpyoidc.util import rndstr from .configure import get_configuration -from .specification import Specification -from .specification import specification_dump -from .specification import specification_load from .state_interface import StateInterface +from .work_condition import work_condition_dump +from .work_condition import work_condition_load +from .work_condition import WorkCondition CLI_REG_MAP = { "userinfo": { @@ -96,7 +96,7 @@ class ServiceContext(OidcContext): "httpc_params": None, "iss_hash": None, "issuer": None, - "specs": Specification, + "work_condition": WorkCondition, "provider_info": None, "requests_dir": None, "registration_response": None, @@ -107,10 +107,9 @@ class ServiceContext(OidcContext): ) special_load_dump = { - "specs": {"load": specification_load, "dump": specification_dump}, + "specs": {"load": work_condition_load, "dump": work_condition_dump}, } - def __init__(self, base_url: Optional[str] = "", keyjar: Optional[KeyJar] = None, @@ -122,9 +121,9 @@ def __init__(self, self.config = config if not client_type or client_type == "oidc": - self.specs = OIDC_Specs() + self.work_condition = OIDC_Specs() elif client_type == "oauth2": - self.specs = OAUTH2_Specs() + self.work_condition = OAUTH2_Specs() else: raise ValueError(f"Unknown client type: {client_type}") @@ -175,7 +174,7 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - self.specs.load_conf(config.conf) + self.work_condition.load_conf(config.conf) def __setitem__(self, key, value): setattr(self, key, value) @@ -233,7 +232,7 @@ def get_sign_alg(self, typ): """ try: - return self.specs.behaviour[CLI_REG_MAP[typ]["sign"]] + return self.work_condition.behaviour[CLI_REG_MAP[typ]["sign"]] except KeyError: try: return self.provider_info[PROVIDER_INFO_MAP[typ]["sign"]] @@ -252,7 +251,7 @@ def get_enc_alg_enc(self, typ): res = {} for attr in ["enc", "alg"]: try: - _alg = self.specs.behaviour[CLI_REG_MAP[typ][attr]] + _alg = self.work_condition.behaviour[CLI_REG_MAP[typ][attr]] except KeyError: try: _alg = self.provider_info[PROVIDER_INFO_MAP[typ][attr]] @@ -270,4 +269,4 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.specs.get_metadata_claim("client_id") + return self.work_condition.get_metadata_claim("client_id") diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index 4c7425e2..4fcfde75 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -11,17 +11,9 @@ from idpyoidc.constant import JOSE_ENCODED from idpyoidc.constant import JSON_ENCODED from idpyoidc.constant import URL_ENCODED +from idpyoidc.defaults import BASECHR from idpyoidc.exception import UnSupported from idpyoidc.util import importer - -# Since SystemRandom is not available on all systems -try: - import SystemRandom as rnd -except ImportError: - import random as rnd - -from idpyoidc.defaults import BASECHR - from .exception import TimeFormatError from .exception import WrongContentType @@ -268,9 +260,9 @@ def get_deserialization_method(reqresp): if not _ctype: # let's try to detect the format try: - content = reqresp.json() + reqresp.json() return "json" - except: + except Exception: return "urlencoded" # reasonable default ?? if match_to_("application/json", _ctype) or match_to_("application/jrd+json", _ctype): diff --git a/src/idpyoidc/client/specification/__init__.py b/src/idpyoidc/client/work_condition/__init__.py similarity index 75% rename from src/idpyoidc/client/specification/__init__.py rename to src/idpyoidc/client/work_condition/__init__.py index 27e0eb63..75c26150 100644 --- a/src/idpyoidc/client/specification/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -7,42 +7,42 @@ from idpyoidc.util import qualified_name -def specification_dump(info, exclude_attributes): +def work_condition_dump(info, exclude_attributes): return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} -def specification_load(item: dict, **kwargs): +def work_condition_load(item: dict, **kwargs): _class_name = list(item.keys())[0] # there is only one _cls = importer(_class_name) _cls = _cls().load(item[_class_name]) return _cls -class Specification(ImpExp): +class WorkCondition(ImpExp): parameter = { "metadata": None, - "usage": None, + "support": None, "behaviour": None, "callback": None, "_local": None } - attributes = { + metadata_claims = { "redirect_uris": None, - "grant_types": ["authorization_code", "implicit", "refresh_token"], "response_types": ["code"], + "grant_types": ["authorization_code", "implicit", "refresh_token"], + "application_type": "web", + "contacts": None, "client_name": None, - "client_uri": None, "logo_uri": None, - "contacts": None, - "scope": None, - "tos_uri": None, + "client_uri": None, "policy_uri": None, + "tos_uri": None, "jwks_uri": None, "jwks": None, } - rules = { + can_support = { "jwks": None, "jwks_uri": None, "scope": ["openid"], @@ -59,23 +59,23 @@ class Specification(ImpExp): def __init__(self, metadata: Optional[dict] = None, - usage: Optional[dict] = None, + support: Optional[dict] = None, behaviour: Optional[dict] = None ): ImpExp.__init__(self) if isinstance(metadata, dict): - self.metadata = {k: v for k, v in metadata.items() if k in self.attributes} + self.metadata = {k: v for k, v in metadata.items() if k in self.metadata_claims} else: self.metadata = {} - if isinstance(usage, dict): - self.usage = {k: v for k, v in usage.items() if k in self.rules} + if isinstance(support, dict): + self.support = {k: v for k, v in support.items() if k in self.can_support} else: - self.usage = {} + self.support = {} if isinstance(behaviour, dict): - self.behaviour = {k: v for k, v in behaviour.items() if k in self.attributes} + self.behaviour = {k: v for k, v in behaviour.items() if k in self.metadata_claims} else: self.behaviour = {} @@ -91,17 +91,17 @@ def get_metadata_claim(self, key, default=None): else: return default - def get_usage(self, key, default=None): - if key in self.usage: - return self.usage[key] + def get_support(self, key, default=None): + if key in self.support: + return self.support[key] else: return default def set_metadata_claim(self, key, value): self.metadata[key] = value - def set_usage(self, key, value): - self.usage[key] = value + def set_support(self, key, value): + self.support[key] = value def _callback_uris(self, base_url, hex): _red = {} @@ -111,7 +111,7 @@ def _callback_uris(self, base_url, hex): elif type in ["id_token", "id_token token"]: _red['implicit'] = True - if "form_post" in self.usage: + if "form_post" in self.support: _red["form_post"] = True callback_uri = {} @@ -140,29 +140,29 @@ def locals(self, info): def load_conf(self, info): for attr, val in info.items(): - if attr == "usage": + if attr == "support": for k, v in val.items(): - if k in self.rules: - self.set_usage(k, v) + if k in self.can_support: + self.set_support(k, v) elif attr == "metadata": for k, v in val.items(): - if k in self.attributes: + if k in self.metadata_claims: self.set_metadata_claim(k, v) elif attr == "behaviour": self.behaviour = val - elif attr in self.attributes: + elif attr in self.metadata_claims: self.set_metadata_claim(attr, val) - elif attr in self.rules: - self.set_usage(attr, val) + elif attr in self.can_support: + self.set_support(attr, val) # defaults if nothing else is given - for key, val in self.attributes.items(): - if val and key not in self.metadata: - self.set_metadata_claim(key, val) + for key, default in self.metadata_claims.items(): + if default and key not in self.metadata: + self.set_metadata_claim(key, default) - for key, val in self.rules.items(): - if val and key not in self.usage: - self.set_usage(key, val) + for key, default in self.can_support.items(): + if default and key not in self.support: + self.set_support(key, default) self.locals(info) self.verify_rules() diff --git a/src/idpyoidc/client/specification/oauth2.py b/src/idpyoidc/client/work_condition/oauth2.py similarity index 74% rename from src/idpyoidc/client/specification/oauth2.py rename to src/idpyoidc/client/work_condition/oauth2.py index 99e502f3..9e95145b 100644 --- a/src/idpyoidc/client/specification/oauth2.py +++ b/src/idpyoidc/client/work_condition/oauth2.py @@ -1,10 +1,10 @@ from typing import Optional -from idpyoidc.client import specification as sp +from idpyoidc.client import work_condition -class Specification(sp.Specification): - attributes = { +class WorkCondition(work_condition.WorkCondition): + metadata_claims = { "redirect_uris": None, "grant_types": ["authorization_code", "implicit", "refresh_token"], "response_types": ["code"], @@ -39,7 +39,8 @@ class Specification(sp.Specification): def __init__(self, metadata: Optional[dict] = None, - usage: Optional[dict] = None, + support: Optional[dict] = None, behaviour: Optional[dict] = None ): - sp.Specification.__init__(self, metadata=metadata, usage=usage, behaviour=behaviour) + work_condition.WorkCondition.__init__(self, metadata=metadata, support=support, + behaviour=behaviour) diff --git a/src/idpyoidc/client/specification/oidc.py b/src/idpyoidc/client/work_condition/oidc.py similarity index 64% rename from src/idpyoidc/client/specification/oidc.py rename to src/idpyoidc/client/work_condition/oidc.py index dc84aaf9..197b5d5f 100644 --- a/src/idpyoidc/client/specification/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -1,45 +1,45 @@ import os from typing import Optional -from idpyoidc.client import specification +from idpyoidc.client import work_condition -class Specification(specification.Specification): - parameter = specification.Specification.parameter.copy() +class WorkCondition(work_condition.WorkCondition): + parameter = work_condition.WorkCondition.parameter.copy() parameter.update({ "requests_dir": None }) - attributes = { + metadata_claims = { + "redirect_uris": None, + "response_types": ["code"], + "grant_types": ["authorization_code", "implicit", "refresh_token"], "application_type": "web", "contacts": None, "client_name": None, - "client_id": None, "logo_uri": None, "client_uri": None, "policy_uri": None, "tos_uri": None, - "jwks_uri": None, "jwks": None, + "jwks_uri": None, "sector_identifier_uri": None, - "grant_types": ["authorization_code", "implicit", "refresh_token"], - "default_max_age": None, + "subject_type": None, "id_token_signed_response_alg": "RS256", "id_token_encrypted_response_alg": None, "id_token_encrypted_response_enc": None, - "initiate_login_uri": None, - "subject_type": None, - "default_acr_values": None, - "require_auth_time": None, - "redirect_uris": None, "request_object_signing_alg": None, "request_object_encryption_alg": None, "request_object_encryption_enc": None, + "default_max_age": None, + "require_auth_time": None, + "initiate_login_uri": None, + "default_acr_values": None, "request_uris": None, - "response_types": ["code"] + "client_id": None, } - rules = { + can_support = { "form_post": None, "jwks": None, "jwks_uri": None, @@ -60,28 +60,18 @@ class Specification(specification.Specification): def __init__(self, metadata: Optional[dict] = None, - usage: Optional[dict] = None, + support: Optional[dict] = None, behaviour: Optional[dict] = None, ): - specification.Specification.__init__(self, metadata=metadata, usage=usage, - behaviour=behaviour) - - # def construct_uris(self, base_url, hex): - # if "request_uri" in self.usage: - # if self.usage["request_uri"]: - # _dir = self.get("requests_dir") - # if _dir: - # self.set_metadata("request_uris", Service.get_uri(base_url, _dir, hex)) - # else: - # self.set_metadata("request_uris", - # Service.get_uri(base_url, self.callback_path["requests"], hex)) + work_condition.WorkCondition.__init__(self, metadata=metadata, support=support, + behaviour=behaviour) def verify_rules(self): - if self.get_usage("request_parameter") and self.get_usage("request_uri"): + if self.get_support("request_parameter") and self.get_support("request_uri"): raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'.") # default is jwks_uri - if not self.get_usage("jwks") and not self.get_usage('jwks_uri'): - self.set_usage('jwks_uri', True) + if not self.get_support("jwks") and not self.get_support('jwks_uri'): + self.set_support('jwks_uri', True) def locals(self, info): requests_dir = info.get("requests_dir") diff --git a/src/idpyoidc/impexp.py b/src/idpyoidc/impexp.py index 587d4f19..297282fd 100644 --- a/src/idpyoidc/impexp.py +++ b/src/idpyoidc/impexp.py @@ -86,10 +86,8 @@ def load_attr( ) -> Any: if load_args: _kwargs = {"load_args": load_args} - _load_args = load_args else: _kwargs = {} - _load_args = {} if init_args: _kwargs["init_args"] = init_args diff --git a/src/idpyoidc/logging.py b/src/idpyoidc/logging.py index 998ad7f5..5d7e302a 100755 --- a/src/idpyoidc/logging.py +++ b/src/idpyoidc/logging.py @@ -4,8 +4,6 @@ from logging.config import dictConfig from typing import Optional -import yaml - from idpyoidc.util import load_config_file LOGGING_CONF = "logging.yaml" diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 67410091..639c7f93 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -288,7 +288,7 @@ def verify_id_token(msg, check_hash=False, claim="id_token", **kwargs): _signed = False _sign_alg = kwargs.get("sigalg") if _sign_alg == "none": - _allowed = True + pass else: # There might or might not be a specified signing alg if kwargs.get("allow_sign_alg_none", False) is False: logger.info("Signing algorithm None not allowed") diff --git a/src/idpyoidc/message/oidc/session.py b/src/idpyoidc/message/oidc/session.py index 73026c0b..9ac4cd5f 100644 --- a/src/idpyoidc/message/oidc/session.py +++ b/src/idpyoidc/message/oidc/session.py @@ -4,21 +4,20 @@ from idpyoidc.exception import MessageException from idpyoidc.exception import NotForMe +from idpyoidc.message import Message from idpyoidc.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from idpyoidc.message import REQUIRED_LIST_OF_STRINGS from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.time_util import utc_time_sans_frac - from ..oauth2 import ResponseMessage +from ..oidc import clear_verified_claims from ..oidc import ID_TOKEN_VERIFY_ARGS -from ..oidc import SINGLE_OPTIONAL_IDTOKEN from ..oidc import IdToken from ..oidc import MessageWithIdToken -from ..oidc import clear_verified_claims +from ..oidc import SINGLE_OPTIONAL_IDTOKEN from ..oidc import verified_claim_name from ..oidc import verify_id_token @@ -136,13 +135,8 @@ def verify(self, **kwargs): except KeyError: _skew = 0 - try: - _exp = self["iat"] - except KeyError: - pass - else: - if self["iat"] > (_now + _skew): - raise ValueError("Invalid issued_at time") + if 'iat' in self and self["iat"] > (_now + _skew): + raise ValueError("Invalid issued_at time") _allowed = kwargs.get("allowed_sign_alg") if _allowed and self.jws_header["alg"] != _allowed: diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index cd3e2c2d..65caac50 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -4,7 +4,6 @@ from typing import Dict from typing import Optional from typing import Union -from urllib.parse import unquote_plus from cryptojwt.exception import BadSignature from cryptojwt.exception import Invalid diff --git a/src/idpyoidc/server/client_configure.py b/src/idpyoidc/server/client_configure.py index a8157354..b2eb1acd 100644 --- a/src/idpyoidc/server/client_configure.py +++ b/src/idpyoidc/server/client_configure.py @@ -35,12 +35,6 @@ class ClientConfiguration(RegistrationResponse): def verify(self, **kwargs): RegistrationResponse.verify(self, **kwargs) - _server_get = kwargs.get("server_get") - if _server_get: - _endpoint_context = _server_get("endpoint_context") - else: - _endpoint_context = None - if "add_claims" in self: if not set(self["add_claims"].keys()).issubset({"always", "by_scope"}): _diff = set(self["add_claims"].keys()).difference({"always", "by_scope"}) diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 10166742..3b2cf50c 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -335,7 +335,7 @@ class Authorization(Endpoint): response_placement = "url" endpoint_name = "authorization_endpoint" name = "authorization" - provider_info_attributes = { + metadata_claims = { "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, @@ -615,12 +615,12 @@ def _unwrap_identity(self, identity): _uid = as_unicode(identity['uid']) try: _id = b64d(as_bytes(_uid)) - except Exception as err: + except Exception: return identity else: try: _id = b64d(as_bytes(identity)) - except Exception as err: + except Exception: return identity try: diff --git a/src/idpyoidc/server/oauth2/server_metadata.py b/src/idpyoidc/server/oauth2/server_metadata.py index ccc1922c..3e0230d8 100755 --- a/src/idpyoidc/server/oauth2/server_metadata.py +++ b/src/idpyoidc/server/oauth2/server_metadata.py @@ -1,8 +1,6 @@ import logging from idpyoidc.message import oauth2 - -from idpyoidc.message import oidc from idpyoidc.server.endpoint import Endpoint logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index e7c4fe85..431ae7ad 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -7,16 +7,13 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.message.oidc import TokenErrorResponse -from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError from idpyoidc.server.oauth2.token_helper import AccessTokenHelper from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper from idpyoidc.server.session import MintingNotAllowed -from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING from idpyoidc.util import importer logger = logging.getLogger(__name__) @@ -82,7 +79,7 @@ def configure_grant_types(self, grant_types_helpers): raise ProcessError(f"Failed to initialize class {grant_class}: {e}") def _post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): grant_type = request["grant_type"] _helper = self.helper.get(grant_type) diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index 6a8e05aa..8b17850f 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -728,17 +728,17 @@ def process_request(self, request, **kwargs): def _validate_configuration(self, config): if "requested_token_types_supported" not in config: raise ImproperlyConfigured( - f"Missing 'requested_token_types_supported' from Token Exchange configuration" + "Missing 'requested_token_types_supported' from Token Exchange configuration" ) if "policy" not in config: - raise ImproperlyConfigured(f"Missing 'policy' from Token Exchange configuration") + raise ImproperlyConfigured("Missing 'policy' from Token Exchange configuration") if "" not in config["policy"]: raise ImproperlyConfigured( - f"Default Token Exchange policy configuration is not defined" + "Default Token Exchange policy configuration is not defined" ) if "callable" not in config["policy"][""]: raise ImproperlyConfigured( - f"Missing 'callable' from default Token Exchange policy configuration" + "Missing 'callable' from default Token Exchange policy configuration" ) _default_requested_token_type = config.get("default_requested_token_type", diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index 653628f8..ef77ace9 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -94,6 +94,9 @@ class Authorization(authorization.Authorization): "request_object_encryption_enc_values_supported": None, "grant_types_supported": ["authorization_code", "implicit"], "claim_types_supported": ["normal", "aggregated", "distributed"], + } + metadata_claims = { + } default_capabilities = { "client_authn_method": ["request_param", "public"], diff --git a/src/idpyoidc/server/token/__init__.py b/src/idpyoidc/server/token/__init__.py index c01dbd4c..8c92e562 100755 --- a/src/idpyoidc/server/token/__init__.py +++ b/src/idpyoidc/server/token/__init__.py @@ -3,7 +3,6 @@ from typing import Optional from cryptojwt import as_unicode -from cryptojwt.jwe.fernet import FernetEncrypter from idpyoidc.encrypter import init_encrypter from idpyoidc.server.util import lv_pack diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index 2e0e5023..241bcebf 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -2,7 +2,6 @@ import logging from idpyoidc.util import importer - from .exception import OidcEndpointError logger = logging.getLogger(__name__) @@ -64,6 +63,7 @@ def build_endpoints(conf, server_get, issuer): class JSONDictDB(object): + def __init__(self, filename): with open(filename, "r") as f: self._db = json.load(f) @@ -100,7 +100,7 @@ def lv_unpack(txt): while txt: l, v = txt.split(":", 1) res.append(v[: int(l)]) - txt = v[int(l) :] + txt = v[int(l):] return res @@ -176,7 +176,6 @@ def execute(spec, **kwargs): else: return kwargs - # def sector_id_from_redirect_uris(uris): # if not uris: # return "" diff --git a/src/idpyoidc/time_util.py b/src/idpyoidc/time_util.py index 294ca62b..3ff0838b 100644 --- a/src/idpyoidc/time_util.py +++ b/src/idpyoidc/time_util.py @@ -29,7 +29,6 @@ TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" TIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") - logger = logging.getLogger(__name__) @@ -105,11 +104,11 @@ def parse_duration(duration): try: mod = duration[index:].index(code) try: - dic[typ] = int(duration[index : index + mod]) + dic[typ] = int(duration[index: index + mod]) except ValueError: if code == "S": try: - dic[typ] = float(duration[index : index + mod]) + dic[typ] = float(duration[index: index + mod]) except ValueError: raise TimeUtilError("Not a float") else: @@ -186,7 +185,7 @@ def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0 def time_a_while_ago( - days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 ): """ Will return a time specification for a time sometime in the past. @@ -206,14 +205,14 @@ def time_a_while_ago( def in_a_while( - days=0, - seconds=0, - microseconds=0, - milliseconds=0, - minutes=0, - hours=0, - weeks=0, - time_format=TIME_FORMAT, + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, ): """ :param days: @@ -235,14 +234,14 @@ def in_a_while( def a_while_ago( - days=0, - seconds=0, - microseconds=0, - milliseconds=0, - minutes=0, - hours=0, - weeks=0, - time_format=TIME_FORMAT, + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, ): """ @@ -362,7 +361,7 @@ def time_sans_frac(): def epoch_in_a_while( - days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 ): """ Return the number of seconds since epoch a while from now. diff --git a/src/idpyoidc/util.py b/src/idpyoidc/util.py index 4dcf2d67..ab515d00 100644 --- a/src/idpyoidc/util.py +++ b/src/idpyoidc/util.py @@ -84,6 +84,7 @@ def split_uri(uri: str) -> [str, Union[dict, None]]: class QPKey: + def serialize(self, str): return quote_plus(str) @@ -92,6 +93,7 @@ def deserialize(self, str): class JSON: + def serialize(self, str): return json.dumps(str) @@ -100,6 +102,7 @@ def deserialize(self, str): class PassThru: + def serialize(self, str): return str @@ -139,6 +142,7 @@ def add_path(url, path): else: return "{}/{}".format(url, path) + def qualified_name(cls): """Does both classes and class instances diff --git a/tests/request123456.jwt b/tests/request123456.jwt index ce7ace37..7c63cbe7 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJwaXlzeUhFaXBnRzhmRW8tS1E3MG1sbmRvRkw1QllndU9vdXBhQ0VCdW8wIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjUyMTc1MjcwLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.ALM1q_Zg03PXMKdGMJ_Lnd6_H15m9eqKBMlNaMIBBo15JPvBmwXq_ZGRVdiPg2-MarDTacx59G40qHL8L6C7oWC8MTk8UlkQ2Nbk5zZWtywu8jUiogbllgWJXalt3vczio5pKiZZB36qMo2CRot0BAGjgyewnO0e4wXY_rKoZOMVo1clejAboZJ3tfpIWmY1xnr4wsYR9hDIz9pcHAYmJ4n-j4KDhVXFCWbJ4X8gSE0ezyPRn8snr-_U0by2PRENSEN7_tGRhq3fXtDomzo7dG3pEUeov6lDdtt70EL5c9a_vo0XifraPVmQpiDqD2az6iBTm7wNxhH5KpLJoND2qw \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJHb2UzRnZxMkpnM3hsMHVYZ3dWdUFVWmFyUk8wVTFmSk05c0pTMl9sTHI4IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4MTkxNTEzLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.Wl-8ajDc1j5wBcfuZmUX8Zmp7_-tdWExFmD9LGzXJY0VhiR4XrzuVsu_1Im3ytLlsVpYcYhBGmgaO46B-eXWXQR12hqVXeImcQSBaNy6w8gqydN9IFuN0jAqfQbMUehZrgiyZWR1T59C4hFjePG_3xmg6Cu0bALmwfaisAmD32inumFPSwA8j9yUK9rUBmO_YXvkX3i_PAgyxfTSYBzChsLMkgGcTK7Q-ulJnJ9LIO3ylFq-wh9TI3lTsbZkKunPEN3BWX_LnqAtU8dWa2y9jLRxQ88TOSF8tyQa4VdimOy2_Guy3rVOB-0ZbIcbX6tlNO-NA8nMmiez7fdu44sTWw \ No newline at end of file diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index 251070c3..10c518c2 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -35,8 +35,8 @@ def test_filename_from_webname(self): def test_create_callback_uris(self): base_url = "https://example.com/cli" hex = "0123456789" - self.service_context.specs.construct_redirect_uris(base_url, hex, []) - _uris = self.service_context.specs.get_metadata_claim("redirect_uris") + self.service_context.work_condition.construct_redirect_uris(base_url, hex, []) + _uris = self.service_context.work_condition.get_metadata_claim("redirect_uris") assert len(_uris) == 1 assert _uris == [f"https://example.com/cli/authz_cb/{hex}"] @@ -44,11 +44,11 @@ def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") assert _alg is None - self.service_context.specs.behaviour["id_token_signed_response_alg"] = "RS384" + self.service_context.work_condition.behaviour["id_token_signed_response_alg"] = "RS384" _alg = self.service_context.get_sign_alg("id_token") assert _alg == "RS384" - self.service_context.specs.behaviour = {} + self.service_context.work_condition.behaviour = {} self.service_context.provider_info["id_token_signing_alg_values_supported"] = [ "RS256", "ES256", @@ -60,13 +60,13 @@ def test_get_enc_alg_enc(self): _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": None, "enc": None} - self.service_context.specs.behaviour["userinfo_encrypted_response_alg"] = "RSA1_5" - self.service_context.specs.behaviour["userinfo_encrypted_response_enc"] = "A128CBC+HS256" + self.service_context.work_condition.behaviour["userinfo_encrypted_response_alg"] = "RSA1_5" + self.service_context.work_condition.behaviour["userinfo_encrypted_response_enc"] = "A128CBC+HS256" _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": "RSA1_5", "enc": "A128CBC+HS256"} - self.service_context.specs.behaviour = {} + self.service_context.work_condition.behaviour = {} self.service_context.provider_info["userinfo_encryption_alg_values_supported"] = [ "RSA1_5", "A128KW", diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 510e2fb8..6c9027b3 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -33,7 +33,9 @@ }, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": {} + "kwargs": { + "support": {"request_uris": True} + } }, "accesstoken": { "class": "idpyoidc.client.oidc.access_token.AccessToken", @@ -60,7 +62,7 @@ "backchannel_logout_uri": "https://rp.example.com/back", "backchannel_logout_session_required": True }, - "usage": { + "support": { "backchannel_logout": True } } @@ -101,7 +103,7 @@ def test_create_client(): assert client.get_metadata_claim("userinfo_signed_response_alg") == "ES256" assert client.metadata_claim_contains_value("userinfo_signed_response_alg", "ES256") # How to act - assert client.get_usage_value("request_uri") is True + assert client.get_support("request_uri") is True _conf_args = client.config_args() assert _conf_args diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 4e144823..3a6377ec 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -50,7 +50,7 @@ def test_gather_request_args(self): assert args == {"response_type": "code", "state": "state", 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} - self.entity.set_metadata_value("client_id", "client") + self.entity.set_metadata_claim("client_id", "client") args = self.service.gather_request_args(state="state") assert args == {"client_id": "client", "response_type": "code", "state": "state", 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} @@ -65,7 +65,7 @@ def test_gather_request_args(self): 'redirect_uri': 'https://example.com/cli/authz_cb', } - self.entity.set_metadata_value("redirect_uris", ["https://rp.example.com"]) + self.entity.set_metadata_claim("redirect_uris", ["https://rp.example.com"]) args = self.service.gather_request_args(state="state") assert args == { "client_id": "client", diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index cf55b0cf..5a472863 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -22,7 +22,7 @@ from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity -from idpyoidc.client.specification import Specification +from idpyoidc.client.work_condition import WorkCondition from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message @@ -439,7 +439,8 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # By client preferences request = AccessTokenRequest() - _service_context.specs.set_metadata("token_endpoint_auth_signing_alg", "RS512") + _service_context.work_condition.set_metadata_claim("token_endpoint_auth_signing_alg", + "RS512") csj.construct(request, service=token_service, authn_endpoint="token_endpoint") _jws = factory(request["client_assertion"]) @@ -448,7 +449,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Use provider information is everything else fails request = AccessTokenRequest() - _service_context.specs = Specification() + _service_context.work_condition = WorkCondition() _service_context.provider_info["token_endpoint_auth_signing_alg_values_supported"] = [ "ES256", "RS256", diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 5d181695..1d2f9654 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -429,7 +429,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # By client preferences request = AccessTokenRequest() - entity.set_metadata_value("token_endpoint_auth_signing_alg", "RS512") + entity.set_metadata_claim("token_endpoint_auth_signing_alg", "RS512") csj.construct(request, service=token_service, authn_endpoint="token_endpoint") _jws = factory(request["client_assertion"]) diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index f44bb2a3..8af96908 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -19,7 +19,7 @@ def test_client_info_init(): "base_url": BASE_URL, "requests_dir": "requests", } - ci = ServiceContext(config=config) + ci = ServiceContext(config=config,client_type='oidc') srvcnx = ServiceContext(base_url=BASE_URL).load(ci.dump()) @@ -27,7 +27,7 @@ def test_client_info_init(): if attr == "client_id": assert srvcnx.get_client_id() == config[attr] elif attr == "requests_dir": - assert srvcnx.specs.get("requests_dir") == config[attr] + assert srvcnx.work_condition.get("requests_dir") == config[attr] else: try: val = getattr(srvcnx, attr) @@ -48,7 +48,7 @@ def test_set_and_get_client_secret(): def test_set_and_get_client_id(): service_context = ServiceContext(base_url=BASE_URL) - service_context.specs.set_metadata("client_id", "myself") + service_context.work_condition.set_metadata_claim("client_id", "myself") srvcnx2 = ServiceContext(base_url=BASE_URL).load(service_context.dump()) assert srvcnx2.get_client_id() == "myself" @@ -108,7 +108,7 @@ def create_client_info_instance(self): self.service_context = ServiceContext(config=config) def test_registration_userinfo_sign_enc_algs(self): - self.service_context.specs.behaviour = { + self.service_context.work_condition.behaviour = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -128,7 +128,7 @@ def test_registration_userinfo_sign_enc_algs(self): assert srvcntx.get_enc_alg_enc("userinfo") == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} def test_registration_request_object_sign_enc_algs(self): - self.service_context.specs.behaviour = { + self.service_context.work_condition.behaviour = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -150,7 +150,7 @@ def test_registration_request_object_sign_enc_algs(self): assert srvcntx.get_sign_alg("request_object") == "RS384" def test_registration_id_token_sign_enc_algs(self): - self.service_context.specs.behaviour = { + self.service_context.work_condition.behaviour = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 1b80ca9e..e63b01b0 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1,14 +1,14 @@ import json import os +import pytest +import responses from cryptojwt.exception import UnsupportedAlgorithm from cryptojwt.jws import jws from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar -import pytest -import responses from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.entity import Entity @@ -30,6 +30,7 @@ class Response(object): + def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -70,6 +71,7 @@ def make_keyjar(): class TestAuthorization(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -284,7 +286,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.client_get("service_context").specs.behaviour["verify_args"] = { + self.service.client_get("service_context").work_condition.behaviour["verify_args"] = { "allow_sign_alg_none": allow_sign_alg_none } resp = AuthorizationResponse(state="state", code="code", id_token=_idt) @@ -296,6 +298,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): class TestAuthorizationCallback(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -370,6 +373,7 @@ def test_construct_form_post(self): class TestAccessTokenRequest(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -504,6 +508,7 @@ def test_id_token_nonce_match(self): class TestProviderInfo(object): + @pytest.fixture(autouse=True) def create_service(self): self._iss = ISS @@ -513,7 +518,7 @@ def create_service(self): "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": self._iss, "application_name": "rphandler", - "usage": { + "support": { "scope": ["openid", "profile", "email", "address", "phone"], }, "services": { @@ -568,7 +573,7 @@ def create_service(self): } } } - entity = Entity(keyjar=make_keyjar(), config=client_config) + entity = Entity(keyjar=make_keyjar(), config=client_config, client_type='oidc') entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "provider_info") @@ -643,7 +648,6 @@ def test_post_parse(self): "address", "phone", "offline_access", - "openid", ], "userinfo_signing_alg_values_supported": [ "RS256", @@ -773,7 +777,7 @@ def test_post_parse(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").specs.behaviour == {} + assert self.service.client_get("service_context").work_condition.behaviour == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -782,7 +786,7 @@ def test_post_parse(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").specs.behaviour == { + assert self.service.client_get("service_context").work_condition.behaviour == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -817,7 +821,7 @@ def test_post_parse_2(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").specs.behaviour == {} + assert self.service.client_get("service_context").work_condition.behaviour == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -826,7 +830,7 @@ def test_post_parse_2(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").specs.behaviour == { + assert self.service.client_get("service_context").work_condition.behaviour == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -863,6 +867,7 @@ def create_jws(val): class TestRegistration(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -884,7 +889,7 @@ def test_construct(self): assert len(_req) == 7 def test_config_with_post_logout(self): - self.service.client_get("service_context").specs.set_metadata( + self.service.client_get("service_context").work_condition.set_metadata( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() @@ -945,6 +950,7 @@ def test_config_logout_uri(): class TestUserInfo(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -960,7 +966,7 @@ def create_request(self): entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").specs.behaviour = { + entity.client_get("service_context").work_condition.behaviour = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", @@ -1088,6 +1094,7 @@ def test_unpack_encrypted_response(self): class TestCheckSession(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1117,6 +1124,7 @@ def test_construct(self): class TestCheckID(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1146,6 +1154,7 @@ def test_construct(self): class TestEndSession(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index 63128578..603e4f56 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -27,10 +27,10 @@ def test_request_object_encryption(): "client_secret": "abcdefghijklmnop", } service_context = ServiceContext(keyjar=KEYJAR, config=conf) - _behav = service_context.specs.behaviour + _behav = service_context.work_condition.behaviour _behav["request_object_encryption_alg"] = "RSA1_5" _behav["request_object_encryption_enc"] = "A128CBC-HS256" - service_context.specs.behaviour = _behav + service_context.work_condition.behaviour = _behav _jwe = request_object_encryption(msg.to_json(), service_context, target=RECEIVER) assert _jwe diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 4feb32ba..dd0501e5 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -297,7 +297,7 @@ def test_do_client_registration(self): assert self.rph.hash2issuer["github"] == issuer assert ( - client.client_get("service_context").specs.callback.get("post_logout_redirect_uris") is None + client.client_get("service_context").work_condition.callback.get("post_logout_redirect_uris") is None ) def test_do_client_setup(self): @@ -328,7 +328,7 @@ def test_create_callbacks(self): _context = _srv.client_get("service_context") # add_callbacks(_context, []) - cb = _srv.client_get("service_context").specs.callback + cb = _srv.client_get("service_context").work_condition.callback assert set(cb.keys()) == {"code", "implicit"} _hash = _context.iss_hash @@ -383,7 +383,7 @@ def test_get_client_from_session_key(self): # redo self.rph.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").specs.metadata["redirect_uris"] = [] + cli2.client_get("service_context").work_condition.metadata["redirect_uris"] = [] self.rph.do_client_registration(state=res["state"]) def test_finalize_auth(self): diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index a138871a..23f11161 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -91,7 +91,7 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.specs.behaviour.keys()) == { + assert set(_context.work_condition.behaviour.keys()) == { "token_endpoint_auth_method", "response_types", "scope", diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 4d2c882f..8e5f5e05 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -313,7 +313,7 @@ def test_get_client_from_session_key(self): # redo rph_1.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").specs.metadata["redirect_uris"] = [] + cli2.client_get("service_context").work_condition.metadata["redirect_uris"] = [] rph_1.do_client_registration(state=res["state"]) def test_finalize_auth(self): diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index 2970c906..61cb3d5f 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -36,7 +36,7 @@ def create_request(self): entity.client_get("service_context").issuer = "https://server.otherop.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").specs.behaviour = { + entity.client_get("service_context").work_condition.behaviour = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", From 2afaee0c1f3f9050e804b82c0262c04a08c66d4b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 14 Nov 2022 08:40:12 +0100 Subject: [PATCH 13/76] Reworking the work condition system. This is about going from what the software can do and what the admin wants it to do to what is actually used. --- .../actor/client/oidc/registration.py | 4 +- src/idpyoidc/client/entity.py | 155 +--------------- src/idpyoidc/client/oidc/access_token.py | 6 + src/idpyoidc/client/oidc/authorization.py | 23 ++- .../client/oidc/provider_info_discovery.py | 104 ++++++----- src/idpyoidc/client/oidc/registration.py | 48 ++--- src/idpyoidc/client/service.py | 70 +++---- src/idpyoidc/client/service_context.py | 72 ++++++-- .../client/work_condition/__init__.py | 174 +++++++----------- src/idpyoidc/client/work_condition/oidc.py | 52 ++---- tests/request123456.jwt | 2 +- tests/test_client_02b_entity_metadata.py | 2 +- tests/test_client_21_oidc_service.py | 14 +- 13 files changed, 281 insertions(+), 445 deletions(-) diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index abfe9be9..31cf5a6e 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -148,14 +148,14 @@ class Registration(Service): def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct = [ - self.add_client_behaviour_preference, + self.add_client_preference, # add_redirect_uris, # add_callback_uris, add_jwks_uri_or_jwks, ] self.post_construct = [self.oidc_post_construct] - def add_client_behaviour_preference(self, request_args=None, **kwargs): + def add_client_preference(self, request_args=None, **kwargs): _context = self.client_get("service_context") for prop in self.msg_type.c_param.keys(): if prop in request_args: diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index cf1415fc..784693cf 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -102,7 +102,7 @@ def __init__( self._service_context = ServiceContext( keyjar=keyjar, config=config, jwks_uri=jwks_uri, httpc_params=self.httpc_params, - client_type=client_type + client_type=client_type, client_get=self.client_get ) if config: @@ -163,164 +163,13 @@ def get_entity(self): return self def get_client_id(self): - return self._service_context.get_client_id() + return self._service_context.work_condition.get_usage_claim('client_id') def setup_client_authn_methods(self, config): self._service_context.client_authn_method = client_auth_setup( config.get("client_authn_methods") ) - def collect_metadata(self): - res = {} - for service in self._service.values(): - res.update(service.metadata) - res.update(self._service_context.work_condition.get_metadata()) - return res - - def collect_support(self): - res = {} - for service in self._service.values(): - res.update(service.support) - res.update(self._service_context.work_condition.support) - return res - - def get_metadata_claim(self, claim, default=None): - for service in self._service.values(): - if claim in service.metadata_claims: - return service.get_metadata_claim(claim, default) - - if claim in self._service_context.work_condition.metadata_claims: - return self._service_context.work_condition.get_metadata_claim(claim, default) - - raise KeyError(f"Unknown specs claim: {claim}") - - def get_metadata_claims(self): - claims = [] - for service in self._service.values(): - claims.extend(list(service.metadata_claims.keys())) - - claims.extend(list(self._service_context.work_condition.metadata_claims.keys())) - - return claims - - def get_claim_sources(self): - claims = {'': list(self._service_context.work_condition.metadata_claims.keys())} - for service in self._service.values(): - claims[service.endpoint_name] = list(service.metadata_claims.keys()) - - return claims - - def metadata_claim_contains_value(self, claim, value): - _val = self.get_metadata_claim(claim) - if isinstance(_val, list): - if value in _val: - return True - else: - if value == _val: - return True - - return False - - def will_use(self, facet): - for service in self._service.values(): - if facet in service.can_support.keys(): - if service.support.get(facet): - return True - - if facet in self._service_context.work_condition.can_support.keys(): - if self._service_context.work_condition.get_support(facet): - return True - return False - - def set_metadata_claim(self, claim, value): - """ - Only OK to overwrite a value if the value is the default value - """ - for service in self._service.values(): - if claim in service.metadata_claims: - _def_val = service.metadata_claims[claim] - if _def_val is None: - service.metadata[claim] = value - return True - else: - if service.metadata.get(claim, _def_val) == _def_val: - service.metadata[claim] = value - return True - - if claim in self._service_context.work_condition.metadata_claims: - _def_val = self._service_context.work_condition.metadata_claims[claim] - if _def_val is None: - self._service_context.work_condition.set_metadata_claim(claim, value) - return True - else: - if self._service_context.work_condition.get_metadata_claim(claim, _def_val): - self._service_context.work_condition.set_metadata_claim(claim, value) - return True - return True - - logger.info(f"Unknown set specs claim: {claim}") - return False - - def set_support(self, claim, value): - """ - Only OK to overwrite a value if the value is the default value - """ - for service in self._service.values(): - if claim in service.can_support: - _def_val = service.can_support[claim] - if _def_val is None: - service.support[claim] = value - return True - else: - if service.support[claim] == _def_val: - service.support[claim] = value - return True - - if claim in self._service_context.work_condition.can_support: - _def_val = self._service_context.work_condition.can_support[claim] - if _def_val is None: - self._service_context.work_condition.set_support(claim, value) - return True - else: - if self._service_context.work_condition.can_support[claim] == _def_val: - self._service_context.work_condition.set_support(claim, value) - return True - - logger.info(f"Unknown set support claim: {claim}") - return False - - def get_support(self, claim, default=None): - for service in self._service.values(): - if claim in service.can_support.keys(): - return service.support.get(claim, default) - - if claim in self._service_context.work_condition.can_support: - _val = self._service_context.work_condition.get_support(claim) - if _val: - return _val - else: - return default - - logger.info(f"Unknown support claim: {claim}") - - def construct_uris(self, - issuer: str, - hash_seed: bytes, - callback: Optional[dict]): - _hash = hashlib.sha256() - _hash.update(hash_seed) - _hash.update(as_bytes(issuer)) - _hex = _hash.hexdigest() - - self._service_context.iss_hash = _hex - - _base_url = self._service_context.get("base_url") - for service in self._service.values(): - service.construct_uris(_base_url, _hex) - - if not self._service_context.work_condition.get_metadata_claim("redirect_uris"): - self._service_context.work_condition.construct_redirect_uris(_base_url, _hex, callback) - def backward_compatibility(self, config): _uris = config.get("redirect_uris") if _uris: diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 1292d1a4..9bb41b55 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -5,6 +5,7 @@ from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG +from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name @@ -20,6 +21,11 @@ class AccessToken(access_token.AccessToken): response_cls = oidc.AccessTokenResponse error_msg = oidc.ResponseMessage + supports = { + "token_endpoint_auth_method": '', + "token_endpoint_auth_signing_alg": get_signing_algs + } + def __init__(self, client_get, conf: Optional[dict] = None): access_token.AccessToken.__init__(self, client_get, conf=conf) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 2792702a..f35c461a 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -2,6 +2,7 @@ from typing import Optional from typing import Union +from idpyoidc.client import work_condition from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG @@ -27,16 +28,24 @@ class Authorization(authorization.Authorization): response_cls = oidc.AuthorizationResponse error_msg = oidc.ResponseMessage - can_support = { - "request_uris": None + supports = { + "request_object_signing_alg": work_condition.get_signing_algs, + "request_object_encryption_alg": work_condition.get_encryption_algs, + "request_object_encryption_enc": work_condition.get_encryption_encs, + "request_uris": None, + "request_parameter": None, + "redirect_uris": None, + "response_types": ["code"], + "form_post": None, } callback_path = { - "request_uris": "request", - } - - support_to_uri = { - "request_uris": "request_uris", + "request_uris": "req", + "redirect_uris": { # based on response_types + "code": "authz_cb", + "implicit": "authz_im_cb", + "form_post": "form" + } } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 08a16c7c..fe8af871 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -41,22 +41,23 @@ def add_redirect_uris(request_args, service=None, **kwargs): """ Add redirect_uris to the request arguments. - :param request_args: Incomming request arguments + :param request_args: Incoming request arguments :param service: A link to the service :param kwargs: Possible extra keyword arguments :return: A possibly augmented set of request arguments. """ - _context = service.client_get("service_context") + _work_condition = service.client_get("service_context").work_condition if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. - _cbs = _context.callback - if _cbs: + _callback = _work_condition.get_preference('callback') + if _callback: # Filter out local additions. - _uris = [v for k, v in _cbs.items() if not k.startswith("__")] + _uris = [v for k, v in _callback.items() if not k.startswith("__")] request_args["redirect_uris"] = _uris else: - request_args["redirect_uris"] = _context.metadata["redirect_uris"] + request_args["redirect_uris"] = _work_condition.get_preference( + "redirect_uris", _work_condition.supports.get('redirect_uris')) return request_args, {} @@ -67,7 +68,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): error_msg = ResponseMessage service_name = "provider_info" - metadata_claims = {} + _supports = {} def __init__(self, client_get, conf=None): server_metadata.ServerMetadata.__init__(self, client_get, conf=conf) @@ -82,8 +83,8 @@ def update_service_context(self, resp, **kwargs): def match_preferences(self, pcr=None, issuer=None): """ - Match the clients preferences against what the provider can do. - This is to prepare for later client registration and or what + Match the clients supports against what the provider can do. + This is to prepare for later client registration and/or what functionality the client actually will use. In the client configuration the client preferences are expressed. These are then compared with the Provider Configuration information. @@ -95,70 +96,74 @@ def match_preferences(self, pcr=None, issuer=None): """ _context = self.client_get("service_context") _entity = self.client_get("entity") + _work_condition = _context.work_condition + + _supports = _context.supports() + _prefers = _context.prefers() if not pcr: pcr = _context.provider_info regreq = oidc.RegistrationRequest - - _behaviour = _context.work_condition.behaviour + prefers = {} for _pref, _prov in PREFERENCE2PROVIDER.items(): - if _pref in ["scope"]: - vals = _entity.get_support(_pref) - else: - try: - vals = _entity.get_metadata_claim(_pref) - except KeyError: - continue + _supported_values = _supports.get(_pref) + _preferred_value = _prefers.get(_pref) - if not vals: - continue + if not _preferred_value: + if not _supported_values: + continue + else: + _supported_values = _preferred_value try: - _pvals = pcr[_prov] + _provider_vals = pcr[_prov] except KeyError: try: # If the provider have not specified use what the # standard says is mandatory if at all. - _pvals = PROVIDER_DEFAULT[_pref] + _provider_vals = PROVIDER_DEFAULT[_pref] except KeyError: logger.info("No info from provider on {} and no default".format(_pref)) - _pvals = vals - - if isinstance(vals, str): - if vals in _pvals: - _behaviour[_pref] = vals - else: + _provider_vals = _supported_values + + if not isinstance(_supported_values, list): + if isinstance(_provider_vals, list): + if _supported_values in _provider_vals: + prefers[_pref] = _supported_values + elif _provider_vals == _supported_values: + prefers[_pref] = _supported_values + else: # _supported_values is a list try: vtyp = regreq.c_param[_pref] except KeyError: # Allow non standard claims - if isinstance(vals, list) and isinstance(_pvals, list): - _behaviour[_pref] = [v for v in vals if v in _pvals] - elif isinstance(_pvals, list): - if vals in _pvals: - _behaviour[_pref] = vals - elif type(vals) == type(_pvals): - if vals == _pvals: - _behaviour[_pref] = vals + if isinstance(_supported_values, list) and isinstance(_provider_vals, list): + prefers[_pref] = [v for v in _supported_values if v in _provider_vals] + elif isinstance(_provider_vals, list): + if _supported_values in _provider_vals: + prefers[_pref] = _supported_values + elif type(_supported_values) == type(_provider_vals): + if _supported_values == _provider_vals: + prefers[_pref] = _supported_values else: if isinstance(vtyp[0], list): - _behaviour[_pref] = [] - for val in vals: - if val in _pvals: - _behaviour[_pref].append(val) + prefers[_pref] = [] + for val in _supported_values: + if val in _provider_vals: + prefers[_pref].append(_supported_values) else: - for val in vals: - if val in _pvals: - _behaviour[_pref] = val + for val in _supported_values: + if val in _provider_vals: + prefers[_pref] = val break - if _pref not in _behaviour: + if _pref not in prefers: raise ConfigurationError("OP couldn't match preference:%s" % _pref, pcr) - for key, val in _entity.collect_metadata().items(): - if key in _behaviour: + for key, val in _supports: + if key in prefers: continue if key in ["jwks", "jwks_uri"]: continue @@ -172,7 +177,8 @@ def match_preferences(self, pcr=None, issuer=None): except KeyError: pass if key not in PREFERENCE2PROVIDER: - _behaviour[key] = val + prefers[key] = val - _context.work_condition.behaviour = _behaviour - logger.debug("service_context behaviour: {}".format(_behaviour)) + # stores it all in one place + _context.work_condition.prefer = prefers + logger.debug("Entity prefers: {}".format(prefers)) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 0df84cc8..239bdd35 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -25,23 +25,23 @@ class Registration(Service): def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) - self.pre_construct = [ - self.add_client_behaviour_preference, - # add_redirect_uris, - ] + self.pre_construct = [self.add_client_preference] self.post_construct = [self.oidc_post_construct] - def add_client_behaviour_preference(self, request_args=None, **kwargs): - _context = self.client_get("service_context") - for prop in self.msg_type.c_param.keys(): + def add_client_preference(self, request_args=None, **kwargs): + _work_condition = self.client_get("service_context") + for prop, spec in self.msg_type.c_param.items(): if prop in request_args: continue - try: - request_args[prop] = _context.work_condition.behaviour[prop] - except KeyError: - _val = _context.work_condition.get_metadata_claim(prop) - if _val: + _val = _work_condition.get_preference(prop) + if _val: + if isinstance(_val, list): + if isinstance(spec[0], list): + request_args[prop] = _val + else: + request_args[prop] = _val[0] # get the first one + else: request_args[prop] = _val return request_args, {} @@ -64,25 +64,29 @@ def update_service_context(self, resp, key="", **kwargs): resp["token_endpoint_auth_method"] = "client_secret_basic" _context = self.client_get("service_context") + _work_condition = _context.work_condition + _keyjar = _context.keyjar + _context.registration_response = resp _client_id = resp.get("client_id") if _client_id: - _context.work_condition.set_metadata("client_id", _client_id) - if _client_id not in _context.keyjar: - _context.keyjar.import_jwks( - _context.keyjar.export_jwks(True, ""), issuer_id=_client_id - ) + _context.work_condition.set_usage_claim("client_id", _client_id) + if _client_id not in _keyjar: + _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) _client_secret = resp.get("client_secret") if _client_secret: - _context.client_secret = _client_secret - _context.keyjar.add_symmetric("", _client_secret) - _context.keyjar.add_symmetric(_client_id, _client_secret) + _work_condition.set_usage_claim("client_secret", _client_secret) + # _context.client_secret = _client_secret + _keyjar.add_symmetric("", _client_secret) + _keyjar.add_symmetric(_client_id, _client_secret) try: - _context.client_secret_expires_at = resp["client_secret_expires_at"] + _work_condition.set_usage_claim("client_secret_expires_at", + resp["client_secret_expires_at"]) except KeyError: pass try: - _context.registration_access_token = resp["registration_access_token"] + _work_condition.set_usage_claim("registration_access_token", + resp["registration_access_token"]) except KeyError: pass diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 6cbb0330..4a852ee0 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -12,8 +12,8 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import is_error_message from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oauth2 import is_error_message from idpyoidc.util import importer from .configure import Configuration from .exception import ResponseError @@ -62,11 +62,8 @@ class Service(ImpExp): init_args = ["client_get"] - metadata_claims = {} - can_support = {} - support = {} - support_to_uri = {} - callback_path = {} + _supports = {} + _callback_path = {} def __init__( self, @@ -78,9 +75,8 @@ def __init__( self.client_get = client_get self.default_request_args = {} - self.metadata = {} - self.support = {} - self.callback_uri = {} + self.prefer = conf.get("prefer", {}) + self.use = {} if conf: self.conf = conf @@ -96,22 +92,6 @@ def __init__( if param in conf: setattr(self, param, conf[param]) - md_conf = conf.get("metadata", {}) - if md_conf: - for param, def_val in self.metadata_claims.items(): - if param in md_conf: - self.metadata[param] = md_conf[param] - elif def_val is not None: - self.metadata[param] = def_val - - support_conf = conf.get("support", {}) - if support_conf: - for facet, def_val in self.can_support.items(): - if facet in support_conf: - self.support[facet] = support_conf[facet] - elif def_val is not None: - self.support[facet] = def_val - _default_request_args = conf.get("request_args", {}) if _default_request_args: self.default_request_args = _default_request_args @@ -641,23 +621,29 @@ def parse_response( return resp - def get_conf_attr(self, attr, default=None): - """ - Get the value of an attribute in the configuration - - :param attr: The attribute - :param default: If the attribute doesn't appear in the configuration - return this value - :return: The value of attribute in the configuration or the default - value - """ - if attr in self.conf: - return self.conf[attr] - - return default - - def usage_to_uri(self, usage): - return self.usage_to_uri_map.get(usage) + def supports(self): + res = {} + for key, val in self._supports.items(): + if isinstance(val, Callable): + res[key] = val() + else: + res[key] = val + return res + + # def get_conf_attr(self, attr, default=None): + # """ + # Get the value of an attribute in the configuration + # + # :param attr: The attribute + # :param default: If the attribute doesn't appear in the configuration + # return this value + # :return: The value of attribute in the configuration or the default + # value + # """ + # if attr in self.conf: + # return self.conf[attr] + # + # return default def get_callback_path(self, callback): return self.callback_path.get(callback) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index fdeaa12e..8b69f833 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -3,6 +3,8 @@ common between all the services that are used by OAuth2 client or OpenID Connect Relying Party. """ import copy +import hashlib +from typing import Callable from typing import Optional from typing import Union @@ -111,6 +113,7 @@ class ServiceContext(OidcContext): } def __init__(self, + client_get: Callable, base_url: Optional[str] = "", keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, @@ -119,6 +122,7 @@ def __init__(self, **kwargs): config = get_configuration(config) self.config = config + self.client_get = client_get if not client_type or client_type == "oidc": self.work_condition = OIDC_Specs() @@ -231,15 +235,11 @@ def get_sign_alg(self, typ): :return: """ - try: - return self.work_condition.behaviour[CLI_REG_MAP[typ]["sign"]] - except KeyError: - try: - return self.provider_info[PROVIDER_INFO_MAP[typ]["sign"]] - except (KeyError, TypeError): - pass + _alg = self.work_condition.get_usage_claim(CLI_REG_MAP[typ]["sign"]) + if not _alg: + _alg = self.provider_info.get(PROVIDER_INFO_MAP[typ]["sign"]) - return None + return _alg def get_enc_alg_enc(self, typ): """ @@ -250,14 +250,9 @@ def get_enc_alg_enc(self, typ): res = {} for attr in ["enc", "alg"]: - try: - _alg = self.work_condition.behaviour[CLI_REG_MAP[typ][attr]] - except KeyError: - try: - _alg = self.provider_info[PROVIDER_INFO_MAP[typ][attr]] - except KeyError: - _alg = None - + _alg = self.work_condition.get_usage_claim(CLI_REG_MAP[typ][attr]) + if not _alg: + _alg = self.provider_info.get(PROVIDER_INFO_MAP[typ][attr]) res[attr] = _alg return res @@ -269,4 +264,47 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.work_condition.get_metadata_claim("client_id") + return self.work_condition.get_usage_claim("client_id") + + def collect_usage(self): + services = self. client_get('services') + res = {} + for service in services.values(): + res.update(service.use) + res.update(self.work_condition.use) + return res + + def supports(self): + services = self.client_get('services') + res = {} + for service in services.values(): + res.update(service.supports()) + res.update(self.work_condition.supports()) + return res + + def prefers(self): + services = self.client_get('services') + res = {} + for service in services.values(): + res.update(service.prefer) + res.update(self.work_condition.prefer) + return res + + def construct_uris(self, + issuer: str, + hash_seed: bytes, + callback: Optional[dict]): + _hash = hashlib.sha256() + _hash.update(hash_seed) + _hash.update(as_bytes(issuer)) + _hex = _hash.hexdigest() + + self.iss_hash = _hex + + _base_url = self.get("base_url") + services = self.client_get('services') + for service in services.values(): + service.construct_uris(_base_url, _hex) + + if not self.work_condition.get_usage_claim("redirect_uris"): + self.work_condition.construct_redirect_uris(_base_url, _hex, callback) diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index 75c26150..b59d7ca7 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -1,5 +1,8 @@ +from typing import Callable from typing import Optional +from cryptojwt.jwe import SUPPORTED +from cryptojwt.jws.jws import SIGNER_ALGS from cryptojwt.utils import importer from idpyoidc.client.service import Service @@ -20,104 +23,59 @@ def work_condition_load(item: dict, **kwargs): class WorkCondition(ImpExp): parameter = { - "metadata": None, - "support": None, - "behaviour": None, - "callback": None, + "prefer": None, + "use": None, + "callback_path": None, "_local": None } - metadata_claims = { - "redirect_uris": None, - "response_types": ["code"], - "grant_types": ["authorization_code", "implicit", "refresh_token"], - "application_type": "web", - "contacts": None, - "client_name": None, - "logo_uri": None, - "client_uri": None, - "policy_uri": None, - "tos_uri": None, - "jwks_uri": None, - "jwks": None, - } - - can_support = { - "jwks": None, - "jwks_uri": None, - "scope": ["openid"], - "verify_args": None, - } - - callback_path = { - "requests": "req", - "code": "authz_cb", - "implicit": "authz_im_cb", - } - - callback_uris = ["redirect_uris"] + _supports = {} def __init__(self, - metadata: Optional[dict] = None, - support: Optional[dict] = None, - behaviour: Optional[dict] = None - ): + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): ImpExp.__init__(self) - if isinstance(metadata, dict): - self.metadata = {k: v for k, v in metadata.items() if k in self.metadata_claims} - else: - self.metadata = {} - - if isinstance(support, dict): - self.support = {k: v for k, v in support.items() if k in self.can_support} - else: - self.support = {} - - if isinstance(behaviour, dict): - self.behaviour = {k: v for k, v in behaviour.items() if k in self.metadata_claims} + if isinstance(prefer, dict): + self.prefer = {k: v for k, v in prefer.items() if k in self.supports} else: - self.behaviour = {} + self.prefer = {} - self.callback = {} + self.callback_path = callback_path or {} + self.use = {} self._local = {} + self.callback = {} - def get_metadata(self): - return self.metadata + def get_usage(self): + return self.use - def get_metadata_claim(self, key, default=None): - if key in self.metadata: - return self.metadata[key] - else: - return default + def set_usage_claim(self, key, value): + self.use[key] = value - def get_support(self, key, default=None): - if key in self.support: - return self.support[key] - else: - return default + def get_usage_claim(self, key, default=None): + return self.use.get(key, default) - def set_metadata_claim(self, key, value): - self.metadata[key] = value + def get_preference(self, key, default=None): + return self.prefer.get(key, default) - def set_support(self, key, value): - self.support[key] = value + def set_preference(self, key, value): + self.prefer[key] = value def _callback_uris(self, base_url, hex): - _red = {} - for type in self.get_metadata_claim("response_types", ["code"]): + _uri = [] + for type in self.get_usage_claim("response_types", + self._supports['response_types']): if "code" in type: - _red['code'] = True + _uri.append('code') elif type in ["id_token", "id_token token"]: - _red['implicit'] = True + _uri.append('implicit') - if "form_post" in self.support: - _red["form_post"] = True + if "form_post" in self.supports: + _uri.append("form_post") callback_uri = {} - for key in _red.keys(): - _uri = Service.get_uri(base_url, self.callback_path[key], hex) - callback_uri[key] = _uri + for key in _uri: + callback_uri[key] = Service.get_uri(base_url, self.callback_path[key], hex) return callback_uri def construct_redirect_uris(self, @@ -128,7 +86,7 @@ def construct_redirect_uris(self, callbacks = self._callback_uris(base_url, hex) if callbacks: - self.set_metadata_claim("redirect_uris", [v for k, v in callbacks.items()]) + self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) self.callback = callbacks @@ -140,42 +98,22 @@ def locals(self, info): def load_conf(self, info): for attr, val in info.items(): - if attr == "support": + if attr == "preference": for k, v in val.items(): - if k in self.can_support: - self.set_support(k, v) - elif attr == "metadata": - for k, v in val.items(): - if k in self.metadata_claims: - self.set_metadata_claim(k, v) - elif attr == "behaviour": - self.behaviour = val - elif attr in self.metadata_claims: - self.set_metadata_claim(attr, val) - elif attr in self.can_support: - self.set_support(attr, val) - - # defaults if nothing else is given - for key, default in self.metadata_claims.items(): - if default and key not in self.metadata: - self.set_metadata_claim(key, default) - - for key, default in self.can_support.items(): - if default and key not in self.support: - self.set_support(key, default) + if k in self._supports: + self.set_preference(k, v) + elif attr in self._supports: + self.set_preference(attr, val) + + # # defaults if nothing else is given + # for key, default in self._supports.items(): + # if default and key not in self.prefer: + # self.set_preference(key, default) self.locals(info) self.verify_rules() return self - def bm_get(self, key, default=None): - if key in self.behaviour: - return self.behaviour[key] - elif key in self.metadata: - return self.metadata[key] - - return default - def get(self, key, default=None): if key in self._local: return self._local[key] @@ -187,3 +125,25 @@ def set(self, key, val): def construct_uris(self, *args): pass + + def supports(self): + res = {} + for key, val in self._supports.items(): + if isinstance(val, Callable): + res[key] = val() + else: + res[key] = val + return res + + +def get_signing_algs(): + # Assumes Cryptojwt + return list(SIGNER_ALGS.keys()) + + +def get_encryption_algs(): + return SUPPORTED['alg'] + + +def get_encryption_encs(): + return SUPPORTED['enc'] diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index 197b5d5f..7f538448 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -10,9 +10,7 @@ class WorkCondition(work_condition.WorkCondition): "requests_dir": None }) - metadata_claims = { - "redirect_uris": None, - "response_types": ["code"], + supports = { "grant_types": ["authorization_code", "implicit", "refresh_token"], "application_type": "web", "contacts": None, @@ -25,53 +23,33 @@ class WorkCondition(work_condition.WorkCondition): "jwks_uri": None, "sector_identifier_uri": None, "subject_type": None, - "id_token_signed_response_alg": "RS256", - "id_token_encrypted_response_alg": None, - "id_token_encrypted_response_enc": None, - "request_object_signing_alg": None, - "request_object_encryption_alg": None, - "request_object_encryption_enc": None, + "id_token_signed_response_alg": work_condition.get_signing_algs, + "id_token_encrypted_response_alg": work_condition.get_encryption_algs, + "id_token_encrypted_response_enc": work_condition.get_encryption_encs, "default_max_age": None, "require_auth_time": None, "initiate_login_uri": None, "default_acr_values": None, - "request_uris": None, "client_id": None, - } - - can_support = { - "form_post": None, - "jwks": None, - "jwks_uri": None, - "request_parameter": None, - "request_uri": None, + "client_secret": None, "scope": ["openid"], - "verify_args": None, + # "verify_args": None, } - callback_path = { - "requests": "req", - "code": "authz_cb", - "implicit": "authz_im_cb", - "form_post": "form" - } - - callback_uris = ["redirect_uris"] - def __init__(self, - metadata: Optional[dict] = None, - support: Optional[dict] = None, - behaviour: Optional[dict] = None, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None ): - work_condition.WorkCondition.__init__(self, metadata=metadata, support=support, - behaviour=behaviour) + work_condition.WorkCondition.__init__(self, prefer=prefer, callback_path=callback_path) def verify_rules(self): - if self.get_support("request_parameter") and self.get_support("request_uri"): - raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'.") + if self.get_preference("request_parameter") and self.get_preference("request_uri"): + raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'." + " you can't have both.") + # default is jwks_uri - if not self.get_support("jwks") and not self.get_support('jwks_uri'): - self.set_support('jwks_uri', True) + if not self.get_preference("jwks") and not self.get_preference('jwks_uri'): + self.set_preference('jwks_uri', True) def locals(self, info): requests_dir = info.get("requests_dir") diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 7c63cbe7..88690753 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJHb2UzRnZxMkpnM3hsMHVYZ3dWdUFVWmFyUk8wVTFmSk05c0pTMl9sTHI4IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4MTkxNTEzLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.Wl-8ajDc1j5wBcfuZmUX8Zmp7_-tdWExFmD9LGzXJY0VhiR4XrzuVsu_1Im3ytLlsVpYcYhBGmgaO46B-eXWXQR12hqVXeImcQSBaNy6w8gqydN9IFuN0jAqfQbMUehZrgiyZWR1T59C4hFjePG_3xmg6Cu0bALmwfaisAmD32inumFPSwA8j9yUK9rUBmO_YXvkX3i_PAgyxfTSYBzChsLMkgGcTK7Q-ulJnJ9LIO3ylFq-wh9TI3lTsbZkKunPEN3BWX_LnqAtU8dWa2y9jLRxQ88TOSF8tyQa4VdimOy2_Guy3rVOB-0ZbIcbX6tlNO-NA8nMmiez7fdu44sTWw \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJzcG1VT0V3Z01PS29TdkNXUzJjLVVpcFg5cUlxNHA4UC0wVTBnTW93NjBRIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4MjU3MTgwLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.b0VYiEKj4WcZ48Bcj4mZHxrWGeZUyuGTOqwiznJB7qquohHlKv9ZtJ179uedRI-SKuSsduo6-KVRKHWOT8yDzPYZZFkVemR75GjV8ciMJLL4zOOB6a15tYzGCx0UpAHkvcYj1jAKyfOBDPRa-YFElxzK2dbvEWiBYEhuy6B5oQZxTJagftPUhO1UT9go3NA3H_Ck-nHnpR5QET0ctprTkp8LETp_rGkuGp-ESlwdMj0a-mCDK0iVhv9xP4fXX47gI1XPxTdceRxrda3EWYWfBDn95ykl2L8FbDznBZ6c2yvc6h0DZJGdlvDpoMWjiBtA_IaoKWBKbNbU4PplyiLR8A \ No newline at end of file diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 6c9027b3..351d170e 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -103,7 +103,7 @@ def test_create_client(): assert client.get_metadata_claim("userinfo_signed_response_alg") == "ES256" assert client.metadata_claim_contains_value("userinfo_signed_response_alg", "ES256") # How to act - assert client.get_support("request_uri") is True + assert client.get_support("request_uris") is True _conf_args = client.config_args() assert _conf_args diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index e63b01b0..1b778c13 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -872,29 +872,28 @@ class TestRegistration(object): def create_request(self): self._iss = ISS client_config = { - "client_id": "client_id", - "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": self._iss, "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + client_type='oidc') entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "registration") def test_construct(self): _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 7 + assert len(_req) == 6 def test_config_with_post_logout(self): - self.service.client_get("service_context").work_condition.set_metadata( + self.service.client_get("service_context").work_condition.set_metadata_claim( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 8 + assert len(_req) == 7 assert "post_logout_redirect_uri" in _req @@ -907,7 +906,8 @@ def test_config_with_required_request_uri(): "requests_dir": "requests", "base_url": "https://example.com/cli", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + client_type='oidc') entity.client_get("service_context").issuer = "https://example.com" pi_service = entity.client_get("service", "provider_info") From 17b9eb3aa1fdac77f74e029297a1a90f58af386c Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 17 Nov 2022 08:42:46 +0100 Subject: [PATCH 14/76] Refactoring and putting better names on things. --- .../actor/client/oidc/registration.py | 6 +- src/idpyoidc/client/client_auth.py | 47 +++---- src/idpyoidc/client/entity.py | 109 +++++++-------- src/idpyoidc/client/oauth2/access_token.py | 12 +- src/idpyoidc/client/oauth2/authorization.py | 48 ++++++- src/idpyoidc/client/oauth2/server_metadata.py | 2 +- src/idpyoidc/client/oauth2/utils.py | 4 +- src/idpyoidc/client/oidc/__init__.py | 20 --- src/idpyoidc/client/oidc/access_token.py | 11 +- src/idpyoidc/client/oidc/authorization.py | 42 +++--- src/idpyoidc/client/oidc/end_session.py | 14 +- .../client/oidc/provider_info_discovery.py | 22 --- .../client/oidc/refresh_access_token.py | 2 +- src/idpyoidc/client/oidc/registration.py | 3 +- src/idpyoidc/client/oidc/userinfo.py | 12 +- src/idpyoidc/client/oidc/utils.py | 4 +- src/idpyoidc/client/rp_handler.py | 12 +- src/idpyoidc/client/service.py | 88 ++++-------- src/idpyoidc/client/service_context.py | 130 +++++++++++------- .../client/work_condition/__init__.py | 57 ++++++-- src/idpyoidc/client/work_condition/oauth2.py | 24 +--- src/idpyoidc/client/work_condition/oidc.py | 20 +-- .../client/work_condition/transform.py | 115 ++++++++++++++++ src/idpyoidc/defaults.py | 1 + src/idpyoidc/server/oidc/registration.py | 84 +++++------ tests/request123456.jwt | 2 +- tests/test_client_01_service_context.py | 31 +++-- tests/test_client_02_entity.py | 2 +- tests/test_client_02b_entity_metadata.py | 128 +++++++++-------- tests/test_client_04_service.py | 36 +++-- tests/test_client_06_client_authn.py | 69 ++++++---- tests/test_client_12_client_auth.py | 49 ++++--- .../test_client_14_service_context_impexp.py | 44 +++--- tests/test_client_18_service.py | 2 +- tests/test_client_19_webfinger.py | 2 - tests/test_client_21_oidc_service.py | 27 ++-- tests/test_client_23_pkce.py | 2 +- tests/test_client_24_oic_utils.py | 7 +- tests/test_client_28_rp_handler_oidc.py | 10 +- tests/test_client_29_pushed_auth.py | 2 +- tests/test_client_30_rph_defaults.py | 2 +- tests/test_client_40_dpop.py | 4 +- tests/test_client_41_rp_handler_persistent.py | 6 +- tests/test_client_50_ciba.py | 2 +- tests/test_client_51_identity_assurance.py | 2 +- ...ent_11_base.py => xtest_client_11_base.py} | 2 +- 46 files changed, 733 insertions(+), 587 deletions(-) create mode 100644 src/idpyoidc/client/work_condition/transform.py rename tests/{test_client_11_base.py => xtest_client_11_base.py} (85%) diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index 31cf5a6e..8e196559 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -101,7 +101,7 @@ def _cmp(a, b): def check(entity, claim, expected): try: - _usable = entity.get_metadata_claim(claim) + _usable = entity.get_service_context().get_usage(claim) except KeyError: pass else: @@ -162,10 +162,10 @@ def add_client_preference(self, request_args=None, **kwargs): continue try: - request_args[prop] = _context.work_condition.behaviour[prop] + request_args[prop] = _context.work_condition.get_usage(prop) except KeyError: try: - request_args[prop] = _context.client_preferences[prop] + request_args[prop] = _context.work_condition.get_preference[prop] except KeyError: pass return request_args, {} diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index de7bb357..58f73ede 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -95,7 +95,7 @@ def _get_passwd(request, service, **kwargs): try: passwd = request["client_secret"] except KeyError: - passwd = service.client_get("service_context").client_secret + passwd = service.client_get("service_context").get_usage('client_secret') return passwd @staticmethod @@ -220,9 +220,8 @@ def modify_request(self, request, service, **kwargs): try: request["client_secret"] = kwargs["client_secret"] except (KeyError, TypeError): - if _context.client_secret: - request["client_secret"] = _context.client_secret - else: + request["client_secret"] = _context.get_usage('client_secret') + if not request["client_secret"]: raise AuthnFailure("Missing client secret") # Set the client_id in the the request @@ -458,30 +457,26 @@ def _get_signing_key(self, algorithm, context, kid=None): return signing_key - def _get_audience_and_algorithm(self, context, entity, **kwargs): + def _get_audience_and_algorithm(self, context, **kwargs): algorithm = None # audience for the signed JWT depends on which endpoint # we're talking to. if "authn_endpoint" in kwargs and kwargs["authn_endpoint"] in ["token_endpoint"]: - _alg = context.registration_response.get("token_endpoint_auth_signing_alg") - if _alg: - algorithm = _alg - else: - algorithm = entity.get_metadata_claim("token_endpoint_auth_signing_alg") - if algorithm is None: - _pi = context.provider_info - try: - algs = _pi["token_endpoint_auth_signing_alg_values_supported"] - except KeyError: - algorithm = "RS256" # default - else: - for alg in algs: # pick the first one I support and have keys for - if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar( - alg, context - ): - algorithm = alg - break + algorithm = context.get_usage("token_endpoint_auth_signing_alg") + if algorithm is None: + _pi = context.provider_info + try: + algs = _pi["token_endpoint_auth_signing_alg_values_supported"] + except KeyError: + algorithm = "RS256" # default + else: + for alg in algs: # pick the first one I support and have keys for + if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar( + alg, context + ): + algorithm = alg + break audience = context.provider_info["token_endpoint"] else: @@ -493,8 +488,7 @@ def _get_audience_and_algorithm(self, context, entity, **kwargs): def _construct_client_assertion(self, service, **kwargs): _context = service.client_get("service_context") - _entity = service.client_get("entity") - audience, algorithm = self._get_audience_and_algorithm(_context, _entity, **kwargs) + audience, algorithm = self._get_audience_and_algorithm(_context, **kwargs) if "kid" in kwargs: signing_key = self._get_signing_key(algorithm, _context, kid=kwargs["kid"]) @@ -511,7 +505,8 @@ def _construct_client_assertion(self, service, **kwargs): # construct the signed JWT with the assertions and add # it as value to the 'client_assertion' claim of the request - return assertion_jwt(_entity.get_client_id(), signing_key, audience, algorithm, **_args) + return assertion_jwt(_context.get_usage('client_id'), signing_key, audience, algorithm, + **_args) def modify_request(self, request, service, **kwargs): """ diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 784693cf..8aa657ea 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -1,4 +1,3 @@ -import hashlib import logging import os from typing import Optional @@ -6,7 +5,6 @@ from cryptojwt import KeyJar from cryptojwt.key_jar import init_key_jar -from cryptojwt.utils import as_bytes from idpyoidc.client.client_auth import client_auth_setup from idpyoidc.client.configure import Configuration @@ -18,7 +16,7 @@ logger = logging.getLogger(__name__) -rt2gt = { +RESPONSE_TYPES2GRANT_TYPES = { "code": ["authorization_code"], "id_token": ["implicit"], "id_token token": ["implicit"], @@ -35,7 +33,7 @@ def response_types_to_grant_types(response_types): _rt = response_type.split(" ") _rt.sort() try: - _gt = rt2gt[" ".join(_rt)] + _gt = RESPONSE_TYPES2GRANT_TYPES[" ".join(_rt)] except KeyError: logger.warning("No such response type combination: {}".format(response_types)) else: @@ -44,40 +42,30 @@ def response_types_to_grant_types(response_types): return list(_res) -def set_jwks_uri_or_jwks(entity, service_context, config, jwks_uri, keyjar): +def _set_jwks(service_context, config: Configuration, keyjar: Optional[KeyJar]): + _key_conf = config.get("key_conf") or config.conf.get('key_conf') + + if _key_conf: + keys_args = {k: v for k, v in _key_conf.items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + service_context.set_preference("jwks", _keyjar.export_jwks()) + elif keyjar: + service_context.set_preference("jwks", keyjar.export_jwks()) + + +def set_jwks_uri_or_jwks(service_context, config, jwks_uri, keyjar): # lots of different ways to configure the RP's keys if jwks_uri: - entity.set_support("jwks_uri", True) - entity.set_metadata_claim("jwks_uri", jwks_uri) + service_context.set_preference("jwks_uri", jwks_uri) else: if config.get("jwks_uri"): - entity.set_support("jwks_uri", True) - entity.set_support("jwks", False) - elif config.get("jwks"): - entity.set_support("jwks", True) - entity.set_support("jwks_uri", False) + service_context.set_preference("jwks_uri", jwks_uri) else: - entity.set_support("jwks_uri", False) - if config.get("key_conf"): - keys_args = {k: v for k, v in config.get("key_conf").items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - entity.set_support("jwks", True) - entity.set_metadata_claim("jwks", _keyjar.export_jwks()) - return - elif keyjar: - entity.set_support("jwks", True) - entity.set_metadata_claim("jwks", keyjar.export_jwks()) - return - - for attr in ["jwks_uri", "jwks"]: - if entity.will_use(attr): - _val = getattr(service_context, attr) - if _val: - entity.set_metadata_claim(attr, _val) - return + _set_jwks(service_context, config, keyjar) class Entity(object): + def __init__( self, keyjar: Optional[KeyJar] = None, @@ -118,21 +106,22 @@ def __init__( else: _srvs = DEFAULT_OIDC_SERVICES - self._service = init_services(service_definitions=_srvs, client_get=self.client_get, - metadata=config.conf.get("metadata", {}), - support=config.conf.get("support", {})) + self._service = init_services(service_definitions=_srvs, client_get=self.client_get) self.setup_client_authn_methods(config) - jwks_uri = jwks_uri or self.get_metadata_claim("jwks_uri") - set_jwks_uri_or_jwks(self, self._service_context, config, jwks_uri, _kj) + jwks_uri = jwks_uri or self._service_context.get("jwks_uri") + set_jwks_uri_or_jwks(self._service_context, config, jwks_uri, self._service_context.keyjar) # Deal with backward compatibility self.backward_compatibility(config) - self.construct_uris(self._service_context.issuer, - self._service_context.hash_seed, - config.conf.get("callback")) + self._service_context.work_condition.load_conf(config.conf, + supports=self._service_context.supports()) + + self._service_context.construct_uris(self._service_context.issuer, + self._service_context.hash_seed, + config.conf.get("callback")) def client_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) @@ -163,7 +152,11 @@ def get_entity(self): return self def get_client_id(self): - return self._service_context.work_condition.get_usage_claim('client_id') + _val = self._service_context.work_condition.get_usage('client_id') + if _val: + return _val + else: + return self._service_context.work_condition.get_preference('client_id') def setup_client_authn_methods(self, config): self._service_context.client_authn_method = client_auth_setup( @@ -171,29 +164,18 @@ def setup_client_authn_methods(self, config): ) def backward_compatibility(self, config): + _work_condition = self._service_context.work_condition _uris = config.get("redirect_uris") if _uris: - self.set_metadata_claim("redirect_uris", _uris) + _work_condition.set_preference("redirect_uris", _uris) _dir = config.conf.get("requests_dir") if _dir: - authz_serv = self.get_service('authorization') - if authz_serv: # If this isn't true that's weird. Tests perhaps ? - self.set_support("request_uri", True) - if not os.path.isdir(_dir): - os.makedirs(_dir) - authz_serv.callback_path["request_uris"] = _dir + _work_condition.set_preference('requests_dir', _dir) _pref = config.get("client_preferences", {}) for key, val in _pref.items(): - if self.set_metadata_claim(key, val) is False: - if self.set_support(key, val) is False: - setattr(self, key, val) - - for key, val in config.conf.items(): - if key not in ["port", "domain", "httpc_params", "metadata", "client_preferences", - "support", "services", "add_ons"]: - self.extra[key] = val + _work_condition.set_preference(key, val) auth_request_args = config.conf.get("request_args", {}) if auth_request_args: @@ -204,18 +186,25 @@ def config_args(self): res = {} for id, service in self._service.items(): res[id] = { - "metadata": service.metadata_claims, - "support": service.can_support + "preference": service.supports(), } res[""] = { - "metadata": self._service_context.work_condition.metadata_claims, - "support": self._service_context.work_condition.can_support + "preference": self._service_context.work_condition.supports, } return res def get_callback_uris(self): res = [] for service in self._service.values(): - res.extend(service.callback_uris) - res.extend(self._service_context.work_condition.callback_uris) + for _callback in service.callback_uris(): + _uri = self._service_context.work_condition.get_preference(_callback) + if _uri: + res[_callback] = _uri + # res.extend(self._service_context.work_condition.callback_uris) return res + + def prefers(self): + return self._service_context.work_condition.prefers() + + def use(self): + return self._service_context.work_condition.get_use() \ No newline at end of file diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 666374a2..67c51517 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -3,6 +3,8 @@ from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service +from idpyoidc.client.work_condition import get_client_authn_methods +from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.time_util import time_sans_frac @@ -24,13 +26,9 @@ class AccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" - metadata_claims = { - "token_endpoint_auth_method": "client_secret_basic", - "token_endpoint_auth_signing_alg": "RS256" - } - - usage_rules = { - "token_endpoint_auth_methods": None + _supports = { + "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_signing_alg": get_signing_algs, } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index c55cdbb2..ec3788a3 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -1,5 +1,7 @@ """The service that talks to the OAuth2 Authorization endpoint.""" import logging +from typing import List +from typing import Optional from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri @@ -24,6 +26,18 @@ class Authorization(Service): service_name = "authorization" response_body_type = "urlencoded" + _supports = { + "response_types": ["code"] + } + + _callback_path = { + "redirect_uris": { # based on response_types + "code": "authz_cb", + "implicit": "authz_im_cb", + # "form_post": "form" + } + } + def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter]) @@ -46,7 +60,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: # ar_args["redirect_uri"] = self.client_get("service_context").redirect_uris[0] - ar_args["redirect_uri"] = self.client_get("entity").get_metadata_claim( + ar_args["redirect_uri"] = self.client_get("service_context").get_usage( "redirect_uris")[0] except (KeyError, AttributeError): raise MissingParameter("redirect_uri") @@ -78,3 +92,35 @@ def post_parse_response(self, response, **kwargs): except KeyError: pass return response + + def construct_uris(self, base_url: str, hex: bytes, + targets: Optional[List[str]] = None, + preference: Optional[dict] = None): + if not targets: + targets = list(self._callback_path.keys()) + + res = {} + for uri_name in targets: + spec = self._callback_path.get(uri_name) + if spec: + if uri_name == "redirect_uris": # another layer + _uris = [] + for typ, path in spec.items(): + add = False + if 'response_type' in preference: + if typ in preference['response_type']: + add = True + elif typ in preference: + add = True + elif 'response_type' in self._supports: + if typ in self._supports['response_type']: + add = True + elif typ in self._supports: + add = True + + if add: + _uris.append(self.get_uri(base_url, path, hex)) + res[uri_name] = _uris + elif uri_name in preference or uri_name in self._supports: + res[uri_name] = self.get_uri(base_url, spec, hex) + return res \ No newline at end of file diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index 857c7075..8f2b8929 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -22,7 +22,7 @@ class ServerMetadata(Service): service_name = "server_metadata" http_method = "GET" - metadata_claims = {} + _supports = {} def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index c32deb89..66d176a8 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -37,7 +37,7 @@ def pick_redirect_uri( if context.work_condition.callback: if not response_type: - _conf_resp_types = context.work_condition.behaviour.get("response_types", []) + _conf_resp_types = context.work_condition.get_usage("response_types", []) response_type = request_args.get("response_type") if not response_type and _conf_resp_types: response_type = _conf_resp_types[0] @@ -56,7 +56,7 @@ def pick_redirect_uri( f"redirect_uri={redirect_uri}" ) else: - redirect_uris = entity.get_metadata_claim("redirect_uris", []) + redirect_uris = context.get_usage("redirect_uris", []) if redirect_uris: redirect_uri = redirect_uris[0] else: diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index fdab6050..33e0d86f 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -40,26 +40,6 @@ # This should probably be part of the configuration MAX_AUTHENTICATION_AGE = 86400 -PREFERENCE2PROVIDER = { - # "require_signed_request_object": "request_object_algs_supported", - "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": "request_object_encryption_enc_values_supported", - "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", - "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", - "default_acr_values": "acr_values_supported", - "subject_type": "subject_types_supported", - "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", - "response_types": "response_types_supported", - "grant_types": "grant_types_supported", -} - -PROVIDER2PREFERENCE = dict([(v, k) for k, v in PREFERENCE2PROVIDER.items()]) PROVIDER_DEFAULT = { "token_endpoint_auth_method": "client_secret_basic", diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 9bb41b55..6fc9c783 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -5,6 +5,7 @@ from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG +from idpyoidc.client.work_condition import get_client_authn_methods from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc @@ -21,9 +22,9 @@ class AccessToken(access_token.AccessToken): response_cls = oidc.AccessTokenResponse error_msg = oidc.ResponseMessage - supports = { - "token_endpoint_auth_method": '', - "token_endpoint_auth_signing_alg": get_signing_algs + _supports = { + "token_endpoint_auth_methods_supported": get_client_authn_methods, + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } def __init__(self, client_get, conf: Optional[dict] = None): @@ -61,7 +62,7 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.work_condition.behaviour.get("verify_args") + _verify_args = _context.work_condition.get_usage("verify_args") if _verify_args: if _verify_args: kwargs.update(_verify_args) @@ -91,6 +92,6 @@ def update_service_context(self, resp, key="", **kwargs): def get_authn_method(self): _work_condition = self.client_get("service_context").work_condition try: - return _work_condition.behaviour["token_endpoint_auth_method"] + return _work_condition.get_usage("token_endpoint_auth_method") except KeyError: return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index f35c461a..8b080e34 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -28,18 +28,18 @@ class Authorization(authorization.Authorization): response_cls = oidc.AuthorizationResponse error_msg = oidc.ResponseMessage - supports = { - "request_object_signing_alg": work_condition.get_signing_algs, - "request_object_encryption_alg": work_condition.get_encryption_algs, - "request_object_encryption_enc": work_condition.get_encryption_encs, + _supports = { + "request_object_signing_alg_values_supported": work_condition.get_signing_algs, + "request_object_encryption_alg_values_supported": work_condition.get_encryption_algs, + "request_object_encryption_enc_values_supported": work_condition.get_encryption_encs, + "response_types_supported": ["code", "form_post"], "request_uris": None, "request_parameter": None, + "encrypt_request_object": None, "redirect_uris": None, - "response_types": ["code"], - "form_post": None, } - callback_path = { + _callback_path = { "request_uris": "req", "redirect_uris": { # based on response_types "code": "authz_cb", @@ -50,7 +50,6 @@ class Authorization(authorization.Authorization): def __init__(self, client_get, conf=None): authorization.Authorization.__init__(self, client_get, conf=conf) - self.default_request_args.update({"scope": ["openid"]}) self.pre_construct = [ self.set_state, pre_construct_pick_redirect_uri, @@ -103,7 +102,6 @@ def post_parse_response(self, response, **kwargs): def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): _context = self.client_get("service_context") - _entity = self.client_get("entity") if request_args is None: request_args = {} @@ -111,7 +109,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): try: _response_types = [request_args["response_type"]] except KeyError: - _response_types = _context.work_condition.behaviour.get("response_types") + _response_types = _context.get_usage("response_types") if _response_types: request_args["response_type"] = _response_types[0] else: @@ -119,7 +117,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): # For OIDC 'openid' is required in scope if "scope" not in request_args: - _scope = self.client_get("entity").get_support("scope") + _scope = _context.get_usage("scope") if _scope: request_args["scope"] = _scope else: @@ -151,9 +149,9 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): post_args["request_param"] = "request" del kwargs["request_method"] else: - if _entity.get_support("request_uri"): + if _context.get_usage("request_uri"): post_args["request_param"] = "request_uri" - elif _entity.get_support("request_parameter"): + elif _context.get_usage("request_parameter"): post_args["request_param"] = "request" return request_args, post_args @@ -171,7 +169,7 @@ def get_request_object_signing_alg(self, **kwargs): if not alg: _context = self.client_get("service_context") try: - alg = _context.work_condition.behaviour["request_object_signing_alg"] + alg = _context.work_condition.get_usage("request_object_signing_alg") except KeyError: # Use default alg = "RS256" return alg @@ -209,18 +207,16 @@ def construct_request_parameter( if alg == "none": kwargs["keys"] = [] - _srv_cntx = _context - # This is the issuer of the JWT, that is me ! _issuer = kwargs.get("issuer") if _issuer is None: - kwargs["issuer"] = _srv_cntx.get_client_id() + kwargs["issuer"] = _context.get_client_id() if kwargs.get("recv") is None: try: - kwargs["recv"] = _srv_cntx.provider_info["issuer"] + kwargs["recv"] = _context.provider_info["issuer"] except KeyError: - kwargs["recv"] = _srv_cntx.issuer + kwargs["recv"] = _context.issuer try: del kwargs["service"] @@ -273,15 +269,15 @@ def oidc_post_construct(self, req, **kwargs): if _request_param: del kwargs["request_param"] else: - if _context.work_condition.get_support("request_uri"): + if _context.get_usage("request_uri"): _request_param = "request_uri" - elif _context.work_condition.get_support("request_parameter"): + elif _context.get_usage("request_parameter"): _request_param = "request" _req = None # just a flag if _request_param == "request_uri": kwargs["base_path"] = _context.get("base_url") + "/" + "requests" - kwargs["local_dir"] = _context.work_condition.get("requests_dir", "./requests") + kwargs["local_dir"] = _context.get_usage("requests_dir", "./requests") _req = self.construct_request_parameter(req, _request_param, **kwargs) req["request_uri"] = self.store_request_on_file(_req, **kwargs) elif _request_param == "request": @@ -331,7 +327,7 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.work_condition.behaviour.get("verify_args") + _verify_args = _context.get_usage("verify_args") if _verify_args: kwargs.update(_verify_args) diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 8df64fab..8d8901f7 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -20,7 +20,7 @@ class EndSession(Service): service_name = "end_session" response_body_type = "html" - metadata_claims = { + _supports = { "post_logout_redirect_uris": None, "frontchannel_logout_uri": None, "frontchannel_logout_session_required": None, @@ -28,24 +28,12 @@ class EndSession(Service): "backchannel_logout_session_required": None } - can_support = { - "frontchannel_logout": None, - "backchannel_logout": None, - "post_logout_redirects": None - } - callback_path = { "frontchannel_logout_uri": "fc_logout", "backchannel_logout_uri": "bc_logout", "post_logout_redirect_uris": "session_logout" } - support_to_uri = { - "frontchannel_logout": "frontchannel_logout_uri", - "backchannel_logout": "backchannel_logout_uri", - "post_logout_redirect": "post_logout_redirect_uris" - } - def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct = [ diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index fe8af871..c4d20482 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -9,28 +9,6 @@ logger = logging.getLogger(__name__) -PREFERENCE2PROVIDER = { - # "require_signed_request_object": "request_object_algs_supported", - "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": "request_object_encryption_enc_values_supported", - "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", - "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", - "default_acr_values": "acr_values_supported", - "subject_type": "subject_types_supported", - "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", - "response_types": "response_types_supported", - "grant_types": "grant_types_supported", - "scope": "scopes_supported", -} - -PROVIDER2PREFERENCE = dict([(v, k) for k, v in PREFERENCE2PROVIDER.items()]) - PROVIDER_DEFAULT = { "token_endpoint_auth_method": "client_secret_basic", "id_token_signed_response_alg": "RS256", diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index 30d9d3c0..b6e0ef71 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -10,6 +10,6 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): def get_authn_method(self): _work_condition = self.client_get("service_context").work_condition try: - return _work_condition.behaviour["token_endpoint_auth_method"] + return _work_condition.get_usage("token_endpoint_auth_method") except KeyError: return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 239bdd35..a96974ab 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -20,7 +20,6 @@ class Registration(Service): request_body_type = "json" http_method = "POST" - usage_to_uri_map = {} callback_path = {} def __init__(self, client_get, conf=None): @@ -70,7 +69,7 @@ def update_service_context(self, resp, key="", **kwargs): _context.registration_response = resp _client_id = resp.get("client_id") if _client_id: - _context.work_condition.set_usage_claim("client_id", _client_id) + _context.work_condition.set_usage("client_id", _client_id) if _client_id not in _keyjar: _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) _client_secret = resp.get("client_secret") diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 4ace120c..96bcd42a 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -4,6 +4,9 @@ from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service +from idpyoidc.client.work_condition import get_encryption_algs +from idpyoidc.client.work_condition import get_encryption_encs +from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.exception import MissingSigningKey from idpyoidc.message import Message from idpyoidc.message import oidc @@ -38,10 +41,11 @@ class UserInfo(Service): default_authn_method = "bearer_header" http_method = "GET" - metadata_claims = { - "userinfo_signed_response_alg": None, - "userinfo_encrypted_response_alg": None, - "userinfo_encrypted_response_enc": None + _supports = { + "userinfo_signing_alg_values_supported": get_signing_algs, + "userinfo_encryption_alg_values_supported": get_encryption_algs, + "userinfo_encryption_enc_values_supported": get_encryption_encs, + "encrypt_userinfo": None } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index 7fd075ff..097f6f9f 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -20,7 +20,7 @@ def request_object_encryption(msg, service_context, **kwargs): encalg = kwargs["request_object_encryption_alg"] except KeyError: try: - encalg = service_context.work_condition.behaviour["request_object_encryption_alg"] + encalg = service_context.get_usage("request_object_encryption_alg") except KeyError: return msg @@ -31,7 +31,7 @@ def request_object_encryption(msg, service_context, **kwargs): encenc = kwargs["request_object_encryption_enc"] except KeyError: try: - encenc = service_context.work_condition.behaviour["request_object_encryption_enc"] + encenc = service_context.get_usage("request_object_encryption_enc") except KeyError: raise MissingRequiredAttribute("No request_object_encryption_enc specified") diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 8ef27bd0..8849a710 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -383,9 +383,9 @@ def client_setup( def _get_response_type(self, context, req_args: Optional[dict] = None): if req_args: - return req_args.get("response_type", context.work_condition.behaviour["response_types"][0]) + return req_args.get("response_type", context.work_condition.get_usage("response_types")[0]) else: - return context.work_condition.behaviour["response_types"][0] + return context.work_condition.get_usage("response_types")[0] def init_authorization( self, @@ -422,7 +422,7 @@ def init_authorization( "redirect_uri": pick_redirect_uri( _context, _entity, request_args=req_args, response_type=_response_type ), - "scope": _context.work_condition.behaviour["scope"], + "scope": _context.work_condition.get_usage("scope"), "response_type": _response_type, "nonce": _nonce, } @@ -496,7 +496,7 @@ def get_response_type(client): :param client: A Client instance :return: The response_type """ - return client.service_context.get("behaviour")["response_types"][0] + return client.service_context.work_condition.get_usage("response_types")[0] @staticmethod def get_client_authn_method(client, endpoint): @@ -510,9 +510,9 @@ def get_client_authn_method(client, endpoint): """ if endpoint == "token_endpoint": try: - am = client.client_get("service_context").get("behaviour")[ + am = client.client_get("service_context").work_condition.get_usage( "token_endpoint_auth_method" - ] + ) except KeyError: return "" else: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 4a852ee0..ffbe2f6d 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -2,6 +2,7 @@ import json import logging from typing import Callable +from typing import List from typing import Optional from typing import Union from urllib.parse import urlparse @@ -12,8 +13,8 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.util import importer from .configure import Configuration from .exception import ResponseError @@ -75,8 +76,6 @@ def __init__( self.client_get = client_get self.default_request_args = {} - self.prefer = conf.get("prefer", {}) - self.use = {} if conf: self.conf = conf @@ -117,10 +116,14 @@ def gather_request_args(self, **kwargs): """ ar_args = kwargs.copy() - _entity = self.client_get("entity") - md = _entity.collect_metadata() - _context = self.client_get("service_context") + _use = _context.collect_usage() + if not _use: + _use = _context.map_preferred_to_register() + + if "request_args" in self.conf: + ar_args.update(self.conf["request_args"]) + # Go through the list of claims defined for the message class. # There are a couple of places where information can be found. # Access them in the order of priority @@ -131,20 +134,10 @@ def gather_request_args(self, **kwargs): if prop in ar_args: continue - if prop != "state": - val = _context.get(prop) - else: - val = "" - + val = self.default_request_args.get(prop) if not val: - if "request_args" in self.conf: - val = self.conf["request_args"].get(prop) - if not val: - val = self.default_request_args.get(prop) - if not val: - val = _context.work_condition.behaviour.get(prop) - if not val: - val = md.get(prop) + val = _use.get(prop) + if val: ar_args[prop] = val @@ -630,52 +623,39 @@ def supports(self): res[key] = val return res - # def get_conf_attr(self, attr, default=None): - # """ - # Get the value of an attribute in the configuration - # - # :param attr: The attribute - # :param default: If the attribute doesn't appear in the configuration - # return this value - # :return: The value of attribute in the configuration or the default - # value - # """ - # if attr in self.conf: - # return self.conf[attr] - # - # return default - def get_callback_path(self, callback): - return self.callback_path.get(callback) + return self._callback_path.get(callback) @staticmethod def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" - def construct_uris(self, base_url, hex): - for activity, _support in self.support.items(): - if _support: - uri = self.support_to_uri.get(activity) - if uri and uri not in self.metadata: - self.metadata[uri] = self.get_uri(base_url, self.callback_path[uri], hex) + def construct_uris(self, base_url: str, hex: bytes, + targets: Optional[List[str]] = None, + preference: Optional[dict] = None): + if not targets: + targets = self._callback_path.keys() + res = {} + for uri in targets: + _path = self._callback_path.get(uri) + if _path: + res[uri] = self.get_uri(base_url, _path, hex) + return res - def get_metadata_claim(self, claim, default=None): - try: - return self.metadata[claim] - except KeyError: - return default + def supported(self, claim): + return claim in self._supports + + def callback_uris(self): + return list(self._callback_path.keys()) - def set_metadata_claim(self, key, value): - self.metadata[key] = value -def init_services(service_definitions, client_get, metadata, support): +def init_services(service_definitions, client_get): """ Initiates a set of services :param service_definitions: A dictionary containing service definitions :param client_get: A function that returns different things from the base entity. - :param support: What facets of the service that can be used :return: A dictionary, with service name as key and the service instance as value. """ @@ -694,14 +674,6 @@ def init_services(service_definitions, client_get, metadata, support): else: _srv = service_configuration["class"](**kwargs) - for key, val in metadata.items(): - if key in _srv.metadata_claims and key not in _srv.metadata: - _srv.metadata[key] = val - - for key, val in support.items(): - if key in _srv.can_support and key not in _srv.support: - _srv.support[key] = val - service[_srv.service_name] = _srv return service diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 8b69f833..6afac781 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -4,6 +4,7 @@ """ import copy import hashlib +import logging from typing import Callable from typing import Optional from typing import Union @@ -24,6 +25,10 @@ from .work_condition import work_condition_dump from .work_condition import work_condition_load from .work_condition import WorkCondition +from .work_condition.transform import preferred_to_register +from .work_condition.transform import supported_to_preferred + +logger = logging.getLogger(__name__) CLI_REG_MAP = { "userinfo": { @@ -67,7 +72,6 @@ "client_id": "", "redirect_uris": [], "provider_info": {}, - "behaviour": {}, "callback": {}, "issuer": "" } @@ -113,7 +117,7 @@ class ServiceContext(OidcContext): } def __init__(self, - client_get: Callable, + client_get: Optional[Callable] = None, base_url: Optional[str] = "", keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, @@ -145,24 +149,20 @@ def __init__(self, self.issuer = "" self.httpc_params = {} self.callback = {} - self.client_secret = "" self.client_secret_expires_at = 0 self.provider_info = {} # self.post_logout_redirect_uri = "" # self.redirect_uris = [] self.registration_response = {} - self.requests_dir = "" + # self.requests_dir = "" _def_value = copy.deepcopy(DEFAULT_VALUE) - for param in [ - "client_secret", - "provider_info" - ]: - _val = config.conf.get(param, _def_value[param]) - self.set(param, _val) - if param == "client_secret" and _val: - self.keyjar.add_symmetric("", _val) + _val = config.conf.get("client_secret") + if _val: + self.keyjar.add_symmetric("", _val) + + self.provider_info = config.conf.get("provider_info", {}) _issuer = config.get("issuer") if _issuer: @@ -178,8 +178,6 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - self.work_condition.load_conf(config.conf) - def __setitem__(self, key, value): setattr(self, key, value) @@ -228,18 +226,28 @@ def import_keys(self, keyspec): _bundle = KeyBundle(source=url) self.keyjar.add_kb(iss, _bundle) + def _get_crypt(self, typ, attr): + _item_typ = CLI_REG_MAP.get(typ) + _alg = '' + if _item_typ: + _alg = self.work_condition.get_usage(_item_typ[attr]) + if not _alg: + _alg = self.work_condition.get_preference(_item_typ[attr]) + + if not _alg: + _item_typ = PROVIDER_INFO_MAP.get(typ) + if _item_typ: + _alg = self.provider_info.get(_item_typ[attr]) + + return _alg + def get_sign_alg(self, typ): """ :param typ: ['id_token', 'userinfo', 'request_object'] - :return: + :return: signing algorithm """ - - _alg = self.work_condition.get_usage_claim(CLI_REG_MAP[typ]["sign"]) - if not _alg: - _alg = self.provider_info.get(PROVIDER_INFO_MAP[typ]["sign"]) - - return _alg + return self._get_crypt(typ, 'sign') def get_enc_alg_enc(self, typ): """ @@ -250,10 +258,7 @@ def get_enc_alg_enc(self, typ): res = {} for attr in ["enc", "alg"]: - _alg = self.work_condition.get_usage_claim(CLI_REG_MAP[typ][attr]) - if not _alg: - _alg = self.provider_info.get(PROVIDER_INFO_MAP[typ][attr]) - res[attr] = _alg + res[attr] = self._get_crypt(typ, attr) return res @@ -264,36 +269,39 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.work_condition.get_usage_claim("client_id") + return self.work_condition.get_usage("client_id") def collect_usage(self): - services = self. client_get('services') - res = {} - for service in services.values(): - res.update(service.use) - res.update(self.work_condition.use) - return res + return self.work_condition.use def supports(self): - services = self.client_get('services') res = {} - for service in services.values(): - res.update(service.supports()) + if self.client_get: + services = self.client_get('services') + for service in services.values(): + res.update(service.supports()) res.update(self.work_condition.supports()) return res def prefers(self): - services = self.client_get('services') - res = {} - for service in services.values(): - res.update(service.prefer) - res.update(self.work_condition.prefer) - return res + return self.work_condition.prefer + + def get_preference(self, claim, default=None): + return self.work_condition.get_preference(claim) + + def set_preference(self, key, value): + self.work_condition.set_preference(key, value) + + def get_usage(self, claim, default: Optional[str] = None): + return self.work_condition.get_usage(claim, default) + + def set_usage(self, claim, value): + return self.work_condition.set_usage(claim, value) def construct_uris(self, issuer: str, hash_seed: bytes, - callback: Optional[dict]): + callback: Optional[dict] = None): _hash = hashlib.sha256() _hash.update(hash_seed) _hash.update(as_bytes(issuer)) @@ -302,9 +310,35 @@ def construct_uris(self, self.iss_hash = _hex _base_url = self.get("base_url") - services = self.client_get('services') - for service in services.values(): - service.construct_uris(_base_url, _hex) - - if not self.work_condition.get_usage_claim("redirect_uris"): - self.work_condition.construct_redirect_uris(_base_url, _hex, callback) + if self.client_get: + services = self.client_get('services') + for service in services.values(): + service.construct_uris(base_url=_base_url, hex=_hex, + preference=self.work_condition.prefer) + + # if not self.work_condition.get_usage("redirect_uris"): + # self.work_condition.construct_redirect_uris(_base_url, _hex, callback) + + def prefer_or_support(self, claim): + if claim in self.work_condition.prefer: + return 'prefer' + else: + for service in self.client_get('services').values(): + _res = service.prefer_or_support(claim) + if _res: + return _res + + if claim in self.work_condition.supported(claim): + return 'support' + return None + + def map_supported_to_preferred(self, info: Optional[dict] = None): + self.work_condition.prefer = supported_to_preferred(self.supports(), + self.work_condition.prefer, + info) + return self.work_condition.prefer + + def map_preferred_to_register(self): + self.work_condition.use = preferred_to_register(self.work_condition.prefer, + self.work_condition.use) + return self.work_condition.use \ No newline at end of file diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index b59d7ca7..71c4fcf6 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -1,3 +1,4 @@ +from functools import cmp_to_key from typing import Callable from typing import Optional @@ -5,6 +6,7 @@ from cryptojwt.jws.jws import SIGNER_ALGS from cryptojwt.utils import importer +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.service import Service from idpyoidc.impexp import ImpExp from idpyoidc.util import qualified_name @@ -46,13 +48,13 @@ def __init__(self, self._local = {} self.callback = {} - def get_usage(self): + def get_use(self): return self.use - def set_usage_claim(self, key, value): + def set_usage(self, key, value): self.use[key] = value - def get_usage_claim(self, key, default=None): + def get_usage(self, key, default=None): return self.use.get(key, default) def get_preference(self, key, default=None): @@ -63,7 +65,7 @@ def set_preference(self, key, value): def _callback_uris(self, base_url, hex): _uri = [] - for type in self.get_usage_claim("response_types", + for type in self.get_usage("response_types", self._supports['response_types']): if "code" in type: _uri.append('code') @@ -86,6 +88,7 @@ def construct_redirect_uris(self, callbacks = self._callback_uris(base_url, hex) if callbacks: + self.set_preference('callbacks', callbacks) self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) self.callback = callbacks @@ -96,20 +99,15 @@ def verify_rules(self): def locals(self, info): pass - def load_conf(self, info): + def load_conf(self, info, supports): for attr, val in info.items(): if attr == "preference": for k, v in val.items(): - if k in self._supports: + if k in supports: self.set_preference(k, v) - elif attr in self._supports: + elif attr in supports: self.set_preference(attr, val) - # # defaults if nothing else is given - # for key, default in self._supports.items(): - # if default and key not in self.prefer: - # self.set_preference(key, default) - self.locals(info) self.verify_rules() return self @@ -135,10 +133,39 @@ def supports(self): res[key] = val return res + def supported(self, claim): + return claim in self._supports + + def prefers(self): + return self.prefer + + +SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] + + +def cmp(a, b): + return (a > b) - (a < b) + + +def alg_cmp(a, b): + if a == 'none': + return 1 + elif b == 'none': + return -1 + + _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) + _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) + if _pos1 == _pos2: + return (a > b) - (a < b) + elif _pos1 > _pos2: + return 1 + else: + return -1 + def get_signing_algs(): # Assumes Cryptojwt - return list(SIGNER_ALGS.keys()) + return sorted(list(SIGNER_ALGS.keys()), key=cmp_to_key(alg_cmp)) def get_encryption_algs(): @@ -147,3 +174,7 @@ def get_encryption_algs(): def get_encryption_encs(): return SUPPORTED['enc'] + + +def get_client_authn_methods(): + return list(CLIENT_AUTHN_METHOD.keys()) diff --git a/src/idpyoidc/client/work_condition/oauth2.py b/src/idpyoidc/client/work_condition/oauth2.py index 9e95145b..c4d25440 100644 --- a/src/idpyoidc/client/work_condition/oauth2.py +++ b/src/idpyoidc/client/work_condition/oauth2.py @@ -4,11 +4,12 @@ class WorkCondition(work_condition.WorkCondition): - metadata_claims = { + _supports = { "redirect_uris": None, "grant_types": ["authorization_code", "implicit", "refresh_token"], "response_types": ["code"], "client_id": None, + 'client_secret': None, "client_name": None, "client_uri": None, "logo_uri": None, @@ -22,25 +23,12 @@ class WorkCondition(work_condition.WorkCondition): "software_version": None } - rules = { - "jwks": None, - "jwks_uri": None, - "scope": ["openid"], - "verify_args": None, - } - callback_path = { - "requests": "req", - "code": "authz_cb", - "implicit": "authz_im_cb", - } + callback_path = {} callback_uris = ["redirect_uris"] def __init__(self, - metadata: Optional[dict] = None, - support: Optional[dict] = None, - behaviour: Optional[dict] = None - ): - work_condition.WorkCondition.__init__(self, metadata=metadata, support=support, - behaviour=behaviour) + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): + work_condition.WorkCondition.__init__(self, prefer=prefer, callback_path=callback_path) diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index 7f538448..21f56295 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -10,8 +10,13 @@ class WorkCondition(work_condition.WorkCondition): "requests_dir": None }) - supports = { - "grant_types": ["authorization_code", "implicit", "refresh_token"], + _supports = { + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "id_token_signing_alg_values_supported": work_condition.get_signing_algs, + "id_token_encryption_alg_values_supported": work_condition.get_encryption_algs, + "id_token_encryption_enc_values_supported": work_condition.get_encryption_encs, + "acr_values_supported": None, + "subject_types_supported": ["public", "pairwise", "ephemeral"], "application_type": "web", "contacts": None, "client_name": None, @@ -23,17 +28,15 @@ class WorkCondition(work_condition.WorkCondition): "jwks_uri": None, "sector_identifier_uri": None, "subject_type": None, - "id_token_signed_response_alg": work_condition.get_signing_algs, - "id_token_encrypted_response_alg": work_condition.get_encryption_algs, - "id_token_encrypted_response_enc": work_condition.get_encryption_encs, "default_max_age": None, "require_auth_time": None, "initiate_login_uri": None, - "default_acr_values": None, "client_id": None, "client_secret": None, "scope": ["openid"], # "verify_args": None, + "requests_dir": None, + "encrypt_id_token": None } def __init__(self, @@ -47,10 +50,6 @@ def verify_rules(self): raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'." " you can't have both.") - # default is jwks_uri - if not self.get_preference("jwks") and not self.get_preference('jwks_uri'): - self.set_preference('jwks_uri', True) - def locals(self, info): requests_dir = info.get("requests_dir") if requests_dir: @@ -59,3 +58,4 @@ def locals(self, info): os.makedirs(requests_dir) self.set("requests_dir", requests_dir) + diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py new file mode 100644 index 00000000..6d0c8220 --- /dev/null +++ b/src/idpyoidc/client/work_condition/transform.py @@ -0,0 +1,115 @@ +import logging +from typing import Optional + +from idpyoidc.message.oidc import RegistrationResponse + +logger = logging.getLogger(__name__) + +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + "scope": "scopes_supported", + "display": "display_values_supported", + "claims": "claims_supported", + "request": "request_parameter_supported", + "request_uri": "request_uri_parameter_supported", + 'claims_locales': 'claims_locales_supported', + 'ui_locales': 'ui_locales_supported', +} + +PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) + +REQUEST2REGISTER = { + 'client_id': "client_id", + "client_secret": "client_secret", + # 'acr_values': "default_acr_values" , + # 'max_age': "default_max_age", + 'redirect_uri': "redirect_uris", + 'response_type': "response_types", + 'request_uri': "request_uris", + 'grant_type': "grant_types" +} + + +# AUTHORIZATION_REQUEST = [ +# "acr_values", +# "claims", +# "claims_locales", +# "client_id", +# "display", +# "id_token_hint", +# "login_hint", +# "max_age", +# "nonce", +# "prompt", +# "redirect_uri", +# "registration", +# "request", +# "request_uri", +# "response_mode" +# "response_type", +# "scope", +# "state", +# "ui_locales", +# ] + + +def supported_to_preferred(supported: dict, preference: dict, info: Optional[dict] = None): + for key, val in supported.items(): + if info and key in info: + preference[key] = info[key] + continue + + if val is None: + continue + + if key not in preference: + preference[key] = val + + return preference + + +def preferred_to_register(prefers: dict, use: Optional[dict] = None): + if not use: + use = {} + + for key, spec in RegistrationResponse.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + + _preferred_values = prefers.get(_pref_key) + if not _preferred_values: + continue + + if isinstance(spec[0], list): + if _preferred_values: + use[key] = _preferred_values + else: + if _preferred_values: + if isinstance(_preferred_values, list): + use[key] = _preferred_values[0] + else: + use[key] = _preferred_values + + _rr_keys = list(RegistrationResponse.c_param.keys()) + for key, val in prefers.items(): + if PREFERRED2REGISTER.get(key): + continue + if key not in _rr_keys: + use[key] = val + + logger.debug(f"Entity uses: {use}") + return use diff --git a/src/idpyoidc/defaults.py b/src/idpyoidc/defaults.py index 5e8e65c3..83341e84 100644 --- a/src/idpyoidc/defaults.py +++ b/src/idpyoidc/defaults.py @@ -12,3 +12,4 @@ JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" BASECHR = string.ascii_letters + string.digits + diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 9db22afa..8076f618 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -7,9 +7,9 @@ from urllib.parse import urlencode from urllib.parse import urlparse -from cryptojwt.jws.utils import alg2keytype from cryptojwt.utils import as_bytes +# from idpyoidc.defaults import PREFERENCE2SUPPORTED from idpyoidc.exception import MessageException from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import ClientRegistrationErrorResponse @@ -25,25 +25,6 @@ from idpyoidc.util import sanitize from idpyoidc.util import split_uri -PREFERENCE2PROVIDER = { - # "require_signed_request_object": "request_object_algs_supported", - "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": "request_object_encryption_enc_values_supported", - "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", - "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", - "default_acr_values": "acr_values_supported", - "subject_type": "subject_types_supported", - "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", - "response_types": "response_types_supported", - "grant_types": "grant_types_supported", -} - logger = logging.getLogger(__name__) @@ -124,12 +105,13 @@ def comb_uri(args): args["request_uris"] = val -def random_client_id(length: int = 16, reserved: list = [], **kwargs): +def random_client_id(length: int = 16, reserved: list = None, **kwargs): # create new id och secret client_id = rndstr(16) # cdb client_id MUST be unique! - while client_id in reserved: - client_id = rndstr(16) + if reserved: + while client_id in reserved: + client_id = rndstr(16) return client_id @@ -156,18 +138,18 @@ def __init__(self, *args, **kwargs): def match_client_request(self, request): _context = self.server_get("endpoint_context") - for _pref, _prov in PREFERENCE2PROVIDER.items(): - if _pref in request: - if _pref in ["response_types", "default_acr_values"]: - if not match_sp_sep(request[_pref], _context.provider_info[_prov]): - raise CapabilitiesMisMatch(_pref) - else: - if isinstance(request[_pref], str): - if request[_pref] not in _context.provider_info[_prov]: - raise CapabilitiesMisMatch(_pref) - else: - if not set(request[_pref]).issubset(set(_context.provider_info[_prov])): - raise CapabilitiesMisMatch(_pref) + # for _pref, _prov in PREFERENCE2SUPPORTED.items(): + # if _pref in request: + # if _pref in ["response_types", "default_acr_values"]: + # if not match_sp_sep(request[_pref], _context.provider_info[_prov]): + # raise CapabilitiesMisMatch(_pref) + # else: + # if isinstance(request[_pref], str): + # if request[_pref] not in _context.provider_info[_prov]: + # raise CapabilitiesMisMatch(_pref) + # else: + # if not set(request[_pref]).issubset(set(_context.provider_info[_prov])): + # raise CapabilitiesMisMatch(_pref) def do_client_registration(self, request, client_id, ignore=None): if ignore is None: @@ -236,22 +218,22 @@ def do_client_registration(self, request, client_id, ignore=None): ) # Do I have the necessary keys - for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: - if item in request: - if request[item] in _context.provider_info[PREFERENCE2PROVIDER[item]]: - ktyp = alg2keytype(request[item]) - # do I have this ktyp and for EC type keys the curve - if ktyp not in ["none", "oct"]: - _k = [] - for iss in ["", _context.issuer]: - _k.extend( - _context.keyjar.get_signing_key( - ktyp, alg=request[item], issuer_id=iss - ) - ) - if not _k: - logger.warning('Lacking support for "{}"'.format(request[item])) - del _cinfo[item] + # for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: + # if item in request: + # if request[item] in _context.provider_info[PREFERENCE2SUPPORTED[item]]: + # ktyp = alg2keytype(request[item]) + # # do I have this ktyp and for EC type keys the curve + # if ktyp not in ["none", "oct"]: + # _k = [] + # for iss in ["", _context.issuer]: + # _k.extend( + # _context.keyjar.get_signing_key( + # ktyp, alg=request[item], issuer_id=iss + # ) + # ) + # if not _k: + # logger.warning('Lacking support for "{}"'.format(request[item])) + # del _cinfo[item] t = {"jwks_uri": "", "jwks": None} diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 88690753..b462cac9 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJzcG1VT0V3Z01PS29TdkNXUzJjLVVpcFg5cUlxNHA4UC0wVTBnTW93NjBRIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4MjU3MTgwLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.b0VYiEKj4WcZ48Bcj4mZHxrWGeZUyuGTOqwiznJB7qquohHlKv9ZtJ179uedRI-SKuSsduo6-KVRKHWOT8yDzPYZZFkVemR75GjV8ciMJLL4zOOB6a15tYzGCx0UpAHkvcYj1jAKyfOBDPRa-YFElxzK2dbvEWiBYEhuy6B5oQZxTJagftPUhO1UT9go3NA3H_Ck-nHnpR5QET0ctprTkp8LETp_rGkuGp-ESlwdMj0a-mCDK0iVhv9xP4fXX47gI1XPxTdceRxrda3EWYWfBDn95ykl2L8FbDznBZ6c2yvc6h0DZJGdlvDpoMWjiBtA_IaoKWBKbNbU4PplyiLR8A \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJtdEk0TWk5WkFweTR4TWNLSkF0c3BXVFRwa1RqcUFTTHpLWHg1Y0VhNEt3IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4NjIwMDA1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.cSFmCUFZh6jCHiAC7n6EkC_gIkLfnlH2GXCVTUV2SfF19k2wHVH2L8hLj4SmjreoVYKNkhJdt6qpxpmmAP4dWZorUhFJc4j4vp0rIFflffVSg5db1bCvG4_H_XqJxhQdpcUlqfTTkKiqQ9v4fnbh_mTtDJc8ZHLjHaPrRFsSNTvsGeR366PL8bbSrY7F5CX_Ox86B5gIMKDCNt6Cqywd0TcfN5PFrLAKPe3rH1md3dg85dN64xFupSqKhqXlQ3QggrZDQLbGAUnf3YUeqSn2dGb8Of_hVgzfVN33P2uT6x7kkNRmizXEUlsGZ3IiFsPRC59ZF_rObnsRZrGa_9-uLg \ No newline at end of file diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index 10c518c2..7d6a8baa 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -21,6 +21,7 @@ class TestServiceContext: + @pytest.fixture(autouse=True) def setup(self): self.service_context = ServiceContext(config=MINI_CONFIG) @@ -32,23 +33,23 @@ def test_filename_from_webname(self): _filename = self.service_context.filename_from_webname("https://example.com/cli/jwks.json") assert _filename == "jwks.json" - def test_create_callback_uris(self): - base_url = "https://example.com/cli" - hex = "0123456789" - self.service_context.work_condition.construct_redirect_uris(base_url, hex, []) - _uris = self.service_context.work_condition.get_metadata_claim("redirect_uris") - assert len(_uris) == 1 - assert _uris == [f"https://example.com/cli/authz_cb/{hex}"] + # def test_create_callback_uris(self): + # base_url = "https://example.com/cli" + # hex = "0123456789" + # self.service_context.work_condition.construct_redirect_uris(base_url, hex, []) + # _uris = self.service_context.work_condition.get_metadata_claim("redirect_uris") + # assert len(_uris) == 1 + # assert _uris == [f"https://example.com/cli/authz_cb/{hex}"] def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") assert _alg is None - self.service_context.work_condition.behaviour["id_token_signed_response_alg"] = "RS384" + self.service_context.work_condition.set_preference("id_token_signed_response_alg", "RS384") _alg = self.service_context.get_sign_alg("id_token") assert _alg == "RS384" - self.service_context.work_condition.behaviour = {} + self.service_context.work_condition.prefer = {} self.service_context.provider_info["id_token_signing_alg_values_supported"] = [ "RS256", "ES256", @@ -60,13 +61,15 @@ def test_get_enc_alg_enc(self): _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": None, "enc": None} - self.service_context.work_condition.behaviour["userinfo_encrypted_response_alg"] = "RSA1_5" - self.service_context.work_condition.behaviour["userinfo_encrypted_response_enc"] = "A128CBC+HS256" + self.service_context.work_condition.set_preference("userinfo_encrypted_response_alg", + "RSA1_5") + self.service_context.work_condition.set_preference("userinfo_encrypted_response_enc", + "A128CBC+HS256") _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": "RSA1_5", "enc": "A128CBC+HS256"} - self.service_context.work_condition.behaviour = {} + self.service_context.work_condition.prefer = {} self.service_context.provider_info["userinfo_encryption_alg_values_supported"] = [ "RSA1_5", "A128KW", @@ -83,5 +86,5 @@ def test_get(self): assert self.service_context.get("base_url") == MINI_CONFIG["base_url"] def test_set(self): - self.service_context.set("client_id", "number5") - assert self.service_context.get("client_id") == "number5" + self.service_context.set_preference("client_id", "number5") + assert self.service_context.get_preference("client_id") == "number5" diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index ba97b7da..b3152e55 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -39,7 +39,7 @@ def test_get_service_unsupported(self): assert _srv is None def test_get_client_id(self): - assert self.entity.get_metadata_claim("client_id") == "Number5" + assert self.entity.get_service_context().get_preference("client_id") == "Number5" assert self.entity.client_get("client_id") == "Number5" def test_get_service_by_endpoint_name(self): diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 351d170e..583faf0c 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -9,17 +9,21 @@ "client_secret": "a longesh password", "issuer": ISS, "application_name": "rphandler", - "metadata": { + "preference": { "application_type": "web", "contacts": "support@example.com", "response_types": ["code"], "client_id": "client_id", "redirect_uris": ["https://example.com/cli/authz_cb"], - "request_object_signing_alg": "ES256" - }, - "usage": { + 'request_parameter': True, + "request_object_signing_alg_values_supported": ["ES256"], "scope": ["openid", "profile", "email", "address", "phone"], - "request_uri": True + "token_endpoint_auth_methods_supported": ["private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["ES256"], + "userinfo_signing_alg_values_supported": ["ES256"], + "post_logout_redirect_uris": ["https://rp.example.com/post"], + "backchannel_logout_uri": "https://rp.example.com/back", + "backchannel_logout_session_required": True }, "services": { @@ -33,39 +37,19 @@ }, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": { - "support": {"request_uris": True} - } + "kwargs": {} }, "accesstoken": { "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": { - "metadata": { - "token_endpoint_auth_method": "private_key_jwt", - "token_endpoint_auth_signing_alg": "ES256" - } - } + "kwargs": {} }, "userinfo": { "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": { - "metadata": { - "userinfo_signed_response_alg": "ES256" - }, - } + "kwargs": {} }, "end_session": { "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": { - "metadata": { - "post_logout_redirect_uris": ["https://rp.example.com/post"], - "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True - }, - "support": { - "backchannel_logout": True - } - } + "kwargs": {} } } } @@ -80,50 +64,62 @@ def test_create_client(): client = Entity(config=CLIENT_CONFIG, client_type='oidc') - _md = client.collect_metadata() - assert set(_md.keys()) == {'application_type', - 'backchannel_logout_uri', - "backchannel_logout_session_required", - 'client_id', - 'contacts', - 'grant_types', - 'id_token_signed_response_alg', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_object_signing_alg', - 'request_uris', - 'response_types', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} - - # What's in service configuration has higher priority then metadata. - assert client.get_metadata_claim("contacts") == 'support@example.com' - # Two ways of looking at things - assert client.get_metadata_claim("userinfo_signed_response_alg") == "ES256" - assert client.metadata_claim_contains_value("userinfo_signed_response_alg", "ES256") + client.get_service_context().map_supported_to_preferred() + _pref = client.prefers() + assert set(_pref.keys()) == {'application_type', + 'backchannel_logout_session_required', + 'backchannel_logout_uri', + 'client_id', + 'client_secret', + 'contacts', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'jwks', + 'post_logout_redirect_uris', + 'redirect_uris', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'request_parameter', + 'response_types_supported', + 'scope', + 'subject_types_supported', + 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} + + # What's in service configuration has higher priority then what's just supported. + _context = client.get_service_context() + assert _context.get_preference("contacts") == 'support@example.com' + # + assert _context.get_preference("userinfo_signing_alg_values_supported") == ['ES256'] # How to act - assert client.get_support("request_uris") is True + _context.work_condition.use = _context.map_preferred_to_register() + + assert _context.get_usage("request_uris") is None - _conf_args = client.config_args() + _conf_args = _context.collect_usage() assert _conf_args - total_metadata_args = {} - for key, val in _conf_args.items(): - total_metadata_args.update(val["metadata"]) - ma = list(total_metadata_args.keys()) - ma.sort() - assert len(ma) == 36 + assert len(_conf_args) == 25 rr = set(RegistrationRequest.c_param.keys()) - d = rr.difference(set(ma)) - assert d == {'federation_type', 'organization_name', 'post_logout_redirect_uri'} + d = rr.difference(set(_conf_args)) + assert d == {'initiate_login_uri', 'client_name', 'post_logout_redirect_uri', 'tos_uri', + 'logo_uri', 'jwks_uri', 'federation_type', 'frontchannel_logout_session_required', + 'require_auth_time', 'client_uri', 'frontchannel_logout_uri', 'request_uris', + 'sector_identifier_uri', 'default_max_age', 'organization_name', 'policy_uri', + 'default_acr_values'} def test_create_client_key_conf(): client_config = CLIENT_CONFIG.copy() client_config.update({"key_conf": KEY_CONF}) - client = Entity(config=client_config) - _jwks = client.get_metadata_claim("jwks") + client = Entity(config=client_config, client_type='oidc') + _jwks = client.get_service_context().get_preference("jwks") assert _jwks @@ -131,12 +127,12 @@ def test_create_client_keyjar(): _keyjar = init_key_jar(**KEY_CONF) client_config = CLIENT_CONFIG.copy() - client = Entity(config=client_config, keyjar=_keyjar) - _jwks = client.get_metadata_claim("jwks") + client = Entity(config=client_config, keyjar=_keyjar, client_type='oidc') + _jwks = client.get_service_context().get_preference("jwks") assert _jwks def test_create_client_jwks_uri(): client_config = CLIENT_CONFIG.copy() client = Entity(config=client_config, jwks_uri="https://rp.example.com/jwks_uri.json") - assert client.get_metadata_claim("jwks_uri") + assert client.get_service_context().get_preference("jwks_uri") diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 3a6377ec..9101d377 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -6,6 +6,7 @@ class Response(object): + def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -19,23 +20,25 @@ def __init__(self, status_code, text, headers=None): CLIENT_CONF = { "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "key_conf": {"key_defs": KEYDEFS}, + "client_id": 'CLIENT' } class TestService: + @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( config=CLIENT_CONF, - services={ - "authz": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, - } + services={"authz": {"class": "idpyoidc.client.oidc.authorization.Authorization"}}, + client_type='oidc' ) self.service = self.entity.get_service("authorization") self.service_context = self.entity.get_service_context() + self.service_context.map_supported_to_preferred() def client_get(self, *args): if args[0] == "service_context": @@ -44,34 +47,46 @@ def client_get(self, *args): def test_1(self): assert self.service + def test_use(self): + use = self.service_context.map_preferred_to_register() + + assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', + 'grant_types', 'application_type', 'jwks', 'subject_type', + 'id_token_signed_response_alg', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'request_object_signing_alg', + 'request_object_encryption_alg', + 'request_object_encryption_enc', 'scope'} + def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} args = self.service.gather_request_args(state="state") - assert args == {"response_type": "code", "state": "state", + assert args == {"response_type": "code", "state": "state", 'client_id': 'CLIENT', 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} - self.entity.set_metadata_claim("client_id", "client") + self.service_context.set_usage("client_id", "client") args = self.service.gather_request_args(state="state") assert args == {"client_id": "client", "response_type": "code", "state": "state", 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} - self.service.default_request_args = {"scope": ["openid"]} + self.service_context.set_usage("scope", ["openid", "foo"]) args = self.service.gather_request_args(state="state") assert args == { "client_id": "client", "response_type": "code", - "scope": ["openid"], + "scope": ["openid", "foo"], "state": "state", 'redirect_uri': 'https://example.com/cli/authz_cb', } - self.entity.set_metadata_claim("redirect_uris", ["https://rp.example.com"]) + self.service_context.set_usage("redirect_uri", "https://rp.example.com") args = self.service.gather_request_args(state="state") assert args == { "client_id": "client", "redirect_uri": "https://rp.example.com", "response_type": "code", - "scope": ["openid"], + "scope": ["openid", "foo"], "state": "state", } @@ -126,6 +141,7 @@ def test_parse_response_err(self): class TestAuthorization(object): + @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 5a472863..e5628d08 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -1,29 +1,27 @@ import base64 import os -from urllib.parse import quote_plus +import pytest from cryptojwt.exception import MissingKey -from cryptojwt.jws.jws import JWS from cryptojwt.jws.jws import factory +from cryptojwt.jws.jws import JWS from cryptojwt.jwt import JWT from cryptojwt.key_bundle import KeyBundle -from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import init_key_jar -import pytest +from cryptojwt.key_jar import KeyJar +from idpyoidc.client.client_auth import assertion_jwt from idpyoidc.client.client_auth import AuthnFailure +from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import BearerBody from idpyoidc.client.client_auth import BearerHeader from idpyoidc.client.client_auth import ClientSecretBasic from idpyoidc.client.client_auth import ClientSecretJWT from idpyoidc.client.client_auth import ClientSecretPost from idpyoidc.client.client_auth import PrivateKeyJWT -from idpyoidc.client.client_auth import assertion_jwt -from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity from idpyoidc.client.work_condition import WorkCondition - from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenRequest @@ -38,7 +36,7 @@ CLIENT_CONF = { "issuer": "https://example.com/as", - "redirect_uris": ["https://example.com/cli/authz_cb"], + # "redirect_uris": ["https://example.com/cli/authz_cb"], "client_secret": "white boarding pass", "client_id": CLIENT_ID, } @@ -57,7 +55,7 @@ def _eq(l1, l2): @pytest.fixture def entity(): keyjar = init_key_jar(**KEY_CONF) - return Entity( + _entity = Entity( config=CLIENT_CONF, services={ "base": {"class": "idpyoidc.client.service.Service"}, @@ -67,8 +65,14 @@ def entity(): } } }, - keyjar=keyjar + keyjar=keyjar, + client_type='oidc' ) + # The following two lines is necessary since they replace provider info collection and + # client registration. + _entity.get_service_context().map_supported_to_preferred() + _entity.get_service_context().map_preferred_to_register() + return _entity def test_quote(): @@ -86,16 +90,18 @@ def test_quote(): class TestClientSecretBasic(object): + def test_construct(self, entity): _service = entity.client_get("service", "") - request = _service.construct(redirect_uri="http://example.com", state="ABCDE") + request = _service.construct( + request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csb = ClientSecretBasic() http_args = csb.construct(request, _service) _authz = http_args["headers"]["Authorization"] assert _authz.startswith("Basic ") - _token = _authz.split(" ",1)[1] + _token = _authz.split(" ", 1)[1] assert base64.urlsafe_b64decode(_token) == b'A:white boarding pass' def test_does_not_remove_padding(self): @@ -117,6 +123,7 @@ def test_construct_cc(self): class TestBearerHeader(object): + def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() @@ -189,6 +196,7 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): + def test_construct(self, entity): _token_service = entity.client_get("service", "") request = ResourceRequest(access_token="Sesame") @@ -245,9 +253,11 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): + def test_construct(self, entity): _token_service = entity.client_get("service", "") - request = _token_service.construct(redirect_uri="http://example.com", state="ABCDE") + request = _token_service.construct(request_args={'redirect_uri': "http://example.com", + 'state': "ABCDE"}) csp = ClientSecretPost() http_args = csp.construct(request, service=_token_service) @@ -263,22 +273,25 @@ def test_construct(self, entity): def test_modify_1(self, entity): token_service = entity.client_get("service", "") - request = token_service.construct(redirect_uri="http://example.com", state="ABCDE") + request = token_service.construct(request_args={'redirect_uri': "http://example.com", + 'state': "ABCDE"}) csp = ClientSecretPost() http_args = csp.construct(request, service=token_service) assert "client_secret" in request def test_modify_2(self, entity): _service = entity.client_get("service", "") - request = _service.construct(redirect_uri="http://example.com", state="ABCDE") + request = _service.construct(request_args={'redirect_uri': "http://example.com", + 'state': "ABCDE"}) csp = ClientSecretPost() - _service.client_get("service_context").client_secret = "" + _service.client_get("service_context").set_usage('client_secret', "") # this will fail with pytest.raises(AuthnFailure): http_args = csp.construct(request, service=_service) class TestPrivateKeyJWT(object): + def test_construct(self, entity): token_service = entity.client_get("service", "") kb_rsa = KeyBundle( @@ -336,6 +349,7 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): + def test_client_secret_jwt(self, entity): _service_context = entity.client_get("service_context") _service_context.token_endpoint = "https://example.com/token" @@ -345,7 +359,8 @@ def test_client_secret_jwt(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + # This is not the default + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -359,7 +374,7 @@ def test_client_secret_jwt(self, entity): _kj = KeyJar() _kj.add_symmetric(_service_context.get_client_id(), - _service_context.client_secret, ["sig"]) + _service_context.get_usage('client_secret'), ["sig"]) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "exp", "iat", "jti"]) @@ -379,7 +394,7 @@ def test_get_key_by_kid(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -401,7 +416,7 @@ def test_get_key_by_kid_fail(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -420,7 +435,8 @@ def test_get_audience_and_algorithm_default_alg(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + # This is the default so this line is unnecessary + # _service_context.set_usage("token_endpoint_auth_signing_alg", "RS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -429,7 +445,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): token_service = entity.client_get("service", "") - # Since I have a RSA key this doesn't fail + # Since I have an RSA key this doesn't fail csj.construct(request, service=token_service, authn_endpoint="token_endpoint") _jws = factory(request["client_assertion"]) @@ -439,12 +455,11 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # By client preferences request = AccessTokenRequest() - _service_context.work_condition.set_metadata_claim("token_endpoint_auth_signing_alg", - "RS512") + _service_context.set_usage("token_endpoint_auth_signing_alg", "RS512") csj.construct(request, service=token_service, authn_endpoint="token_endpoint") _jws = factory(request["client_assertion"]) - assert _jws.jwt.headers["alg"] == "RS256" + assert _jws.jwt.headers["alg"] == "RS512" assert _jws.jwt.headers["kid"] == _rsa_key.kid # Use provider information is everything else fails @@ -464,6 +479,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): + def test_client_secret_jwt(self, entity): access_token_service = entity.client_get("service", "") @@ -486,7 +502,7 @@ def test_client_secret_jwt(self, entity): _kj = KeyJar() _kj.add_symmetric(_service_context.get_client_id(), - _service_context.client_secret, usage=["sig"]) + _service_context.get_usage('client_secret'), usage=["sig"]) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) @@ -499,6 +515,7 @@ def test_client_secret_jwt(self, entity): class TestValidClientInfo(object): + def test_valid_service_context(self, entity): _service_context = entity.client_get("service_context") diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 1d2f9654..484bf304 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -1,25 +1,24 @@ import base64 import os -from urllib.parse import quote_plus import pytest from cryptojwt.exception import MissingKey from cryptojwt.jwk.rsa import new_rsa_key -from cryptojwt.jws.jws import JWS from cryptojwt.jws.jws import factory +from cryptojwt.jws.jws import JWS from cryptojwt.jwt import JWT from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar +from idpyoidc.client.client_auth import assertion_jwt from idpyoidc.client.client_auth import AuthnFailure +from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import BearerBody from idpyoidc.client.client_auth import BearerHeader from idpyoidc.client.client_auth import ClientSecretBasic from idpyoidc.client.client_auth import ClientSecretJWT from idpyoidc.client.client_auth import ClientSecretPost from idpyoidc.client.client_auth import PrivateKeyJWT -from idpyoidc.client.client_auth import assertion_jwt -from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity from idpyoidc.defaults import JWT_BEARER @@ -48,7 +47,12 @@ def _eq(l1, l2): @pytest.fixture def entity(): - return Entity(config=CLIENT_CONF) + entity = Entity(config=CLIENT_CONF, client_type='oidc') + # The following two lines is necessary since they replace provider info collection and + # client registration. + entity.get_service_context().map_supported_to_preferred() + entity.get_service_context().map_preferred_to_register() + return entity def test_quote(): @@ -60,15 +64,17 @@ def test_quote(): ) assert ( - http_args["headers"]["Authorization"] == "Basic " - 'Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=' + http_args["headers"]["Authorization"] == "Basic " + 'Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=' ) class TestClientSecretBasic(object): + def test_construct(self, entity): _token_service = entity.client_get("service", "accesstoken") - request = _token_service.construct(redirect_uri="http://example.com", state="ABCDE") + request = _token_service.construct(request_args={'redirect_uri': "http://example.com", + 'state': "ABCDE"}) csb = ClientSecretBasic() http_args = csb.construct(request, _token_service) @@ -102,6 +108,7 @@ def test_construct_cc(self): class TestBearerHeader(object): + def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() @@ -173,6 +180,7 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): + def test_construct(self, entity): _token_service = entity.client_get("service", "accesstoken") request = ResourceRequest(access_token="Sesame") @@ -227,6 +235,7 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): + def test_construct(self, entity): _token_service = entity.client_get("service", "accesstoken") request = _token_service.construct(redirect_uri="http://example.com", state="ABCDE") @@ -258,13 +267,14 @@ def test_modify_2(self, entity): csp = ClientSecretPost() # client secret not in request or kwargs del request["client_secret"] - token_service.client_get("service_context").client_secret = "" + token_service.client_get("service_context").set_usage('client_secret', "") # this will fail with pytest.raises(AuthnFailure): - http_args = csp.construct(request, service=token_service) + csp.construct(request, service=token_service) class TestPrivateKeyJWT(object): + def test_construct(self, entity): token_service = entity.client_get("service", "accesstoken") kb_rsa = KeyBundle( @@ -320,6 +330,7 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): + def test_client_secret_jwt(self, entity): _service_context = entity.client_get("service_context") _service_context.token_endpoint = "https://example.com/token" @@ -329,7 +340,7 @@ def test_client_secret_jwt(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -345,7 +356,7 @@ def test_client_secret_jwt(self, entity): _kj = KeyJar() _kj.add_symmetric(_service_context.get_client_id(), - _service_context.client_secret, ["sig"]) + _service_context.get_usage('client_secret'), ["sig"]) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "exp", "iat", "jti"]) @@ -365,7 +376,7 @@ def test_get_key_by_kid(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -387,7 +398,7 @@ def test_get_key_by_kid_fail(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "HS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -406,7 +417,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): "token_endpoint": "https://example.com/token", } - _service_context.registration_response = {"token_endpoint_auth_signing_alg": "HS256"} + _service_context.set_usage("token_endpoint_auth_signing_alg", "RS256") csj = ClientSecretJWT() request = AccessTokenRequest() @@ -429,7 +440,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # By client preferences request = AccessTokenRequest() - entity.set_metadata_claim("token_endpoint_auth_signing_alg", "RS512") + _service_context.set_usage("token_endpoint_auth_signing_alg", "RS512") csj.construct(request, service=token_service, authn_endpoint="token_endpoint") _jws = factory(request["client_assertion"]) @@ -439,7 +450,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Use provider information is everything else fails request = AccessTokenRequest() # Can't use set_metadata_value since it won't allow me to overwrite a non-default value - token_service.metadata["token_endpoint_auth_signing_alg"] = None + _service_context.set_usage("token_endpoint_auth_signing_alg", None) _service_context.provider_info["token_endpoint_auth_signing_alg_values_supported"] = [ "ES256", "RS256", @@ -453,6 +464,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): + def test_client_secret_jwt(self, entity): access_token_service = entity.client_get("service", "accesstoken") @@ -475,7 +487,7 @@ def test_client_secret_jwt(self, entity): _kj = KeyJar() _kj.add_symmetric(_service_context.get_client_id(), - _service_context.client_secret, usage=["sig"]) + _service_context.get_usage('client_secret'), usage=["sig"]) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) @@ -488,6 +500,7 @@ def test_client_secret_jwt(self, entity): class TestValidClientInfo(object): + def test_valid_service_context(self, entity): _service_context = entity.client_get("service_context") diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index 8af96908..ae3526df 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -1,6 +1,5 @@ import json import os -from urllib.parse import urlsplit import pytest import responses @@ -19,36 +18,37 @@ def test_client_info_init(): "base_url": BASE_URL, "requests_dir": "requests", } - ci = ServiceContext(config=config,client_type='oidc') + ci = ServiceContext(config=config, client_type='oidc') + ci.work_condition.load_conf(config, supports=ci.supports()) + ci.map_supported_to_preferred() + ci.map_preferred_to_register() srvcnx = ServiceContext(base_url=BASE_URL).load(ci.dump()) for attr in config.keys(): if attr == "client_id": assert srvcnx.get_client_id() == config[attr] - elif attr == "requests_dir": - assert srvcnx.work_condition.get("requests_dir") == config[attr] else: try: val = getattr(srvcnx, attr) except AttributeError: - val = srvcnx.get(attr) + val = srvcnx.get_usage(attr) assert val == config[attr] def test_set_and_get_client_secret(): service_context = ServiceContext(base_url=BASE_URL) - service_context.client_secret = "longenoughsupersecret" + service_context.set_usage('client_secret', "longenoughsupersecret") srvcnx2 = ServiceContext(base_url=BASE_URL).load(service_context.dump()) - assert srvcnx2.client_secret == "longenoughsupersecret" + assert srvcnx2.get_usage('client_secret') == "longenoughsupersecret" def test_set_and_get_client_id(): service_context = ServiceContext(base_url=BASE_URL) - service_context.work_condition.set_metadata_claim("client_id", "myself") + service_context.set_usage("client_id", "myself") srvcnx2 = ServiceContext(base_url=BASE_URL).load(service_context.dump()) assert srvcnx2.get_client_id() == "myself" @@ -96,6 +96,7 @@ def verify_alg_support(service_context, alg, usage, typ): class TestClientInfo(object): + @pytest.fixture(autouse=True) def create_client_info_instance(self): config = { @@ -108,18 +109,17 @@ def create_client_info_instance(self): self.service_context = ServiceContext(config=config) def test_registration_userinfo_sign_enc_algs(self): - self.service_context.work_condition.behaviour = { - "application_type": "web", - "redirect_uris": [ - "https://client.example.org/callback", - "https://client.example.org/callback2", - ], - "token_endpoint_auth_method": "client_secret_basic", - "jwks_uri": "https://client.example.org/my_public_keys.jwks", - "userinfo_encrypted_response_alg": "RSA1_5", - "userinfo_encrypted_response_enc": "A128CBC-HS256", - } - + self.service_context.work_condition.use = { + "application_type": "web", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + } srvcntx = ServiceContext(base_url=BASE_URL).load( self.service_context.dump(exclude_attributes=["service_context"]) @@ -128,7 +128,7 @@ def test_registration_userinfo_sign_enc_algs(self): assert srvcntx.get_enc_alg_enc("userinfo") == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} def test_registration_request_object_sign_enc_algs(self): - self.service_context.work_condition.behaviour = { + self.service_context.work_condition.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -150,7 +150,7 @@ def test_registration_request_object_sign_enc_algs(self): assert srvcntx.get_sign_alg("request_object") == "RS384" def test_registration_id_token_sign_enc_algs(self): - self.service_context.work_condition.behaviour = { + self.service_context.work_condition.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", diff --git a/tests/test_client_18_service.py b/tests/test_client_18_service.py index 5402064b..d9ee21c0 100644 --- a/tests/test_client_18_service.py +++ b/tests/test_client_18_service.py @@ -35,7 +35,7 @@ def create_service(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, } service = {"dummy": {"class": DummyService}} diff --git a/tests/test_client_19_webfinger.py b/tests/test_client_19_webfinger.py index e92c9254..a1a289c4 100644 --- a/tests/test_client_19_webfinger.py +++ b/tests/test_client_19_webfinger.py @@ -15,8 +15,6 @@ __author__ = "Roland Hedberg" -SERVICE_CONTEXT = ServiceContext(base_url="https://example.com") - ENTITY = Entity(config={"base_url":"https://example.com"}) diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 1b778c13..8ed5a87a 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -79,8 +79,13 @@ def create_request(self): "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], } - entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=make_keyjar(), config=client_config) - entity.client_get("service_context").issuer = "https://example.com" + entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=make_keyjar(), config=client_config, + client_type='oidc') + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() + self.context = _context self.service = entity.client_get("service", "authorization") def test_construct(self): @@ -173,6 +178,7 @@ def test_request_init(self): def test_request_init_request_method(self): req_args = {"response_type": "code", "state": "state"} self.service.endpoint = "https://example.com/authorize" + self.context.set_usage('request_object_encryption_alg', None) _info = self.service.get_request_parameters(request_args=req_args, request_method="value") assert set(_info.keys()) == {"url", "method", "request"} msg = AuthorizationRequest().from_urlencoded(self.service.get_urlinfo(_info["url"])) @@ -213,6 +219,7 @@ def test_request_param(self): "request_uris": ["https://example.com/request123456.jwt"], } _context.base_url = "https://example.com/" + _context.set_usage('request_object_encryption_alg', None) _info = self.service.get_request_parameters( request_args=req_args, request_method="reference" ) @@ -286,9 +293,9 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.client_get("service_context").work_condition.behaviour["verify_args"] = { + self.service.client_get("service_context").work_condition.set_usage("verify_args", { "allow_sign_alg_none": allow_sign_alg_none - } + }) resp = AuthorizationResponse(state="state", code="code", id_token=_idt) if allow_sign_alg_none: self.service.parse_response(resp.to_urlencoded()) @@ -777,7 +784,7 @@ def test_post_parse(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").work_condition.behaviour == {} + assert self.service.client_get("service_context").work_condition.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -786,7 +793,7 @@ def test_post_parse(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").work_condition.behaviour == { + assert self.service.client_get("service_context").work_condition.use == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -821,7 +828,7 @@ def test_post_parse_2(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").work_condition.behaviour == {} + assert self.service.client_get("service_context").work_condition.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -830,7 +837,7 @@ def test_post_parse_2(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").work_condition.behaviour == { + assert self.service.client_get("service_context").work_condition.use == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -966,7 +973,7 @@ def create_request(self): entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").work_condition.behaviour = { + entity.client_get("service_context").work_condition.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", @@ -1189,7 +1196,7 @@ def test_authz_service_conf(): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, } services = { diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index b7294389..6d92ce58 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -48,7 +48,7 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "add_ons": { "pkce": { "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index 603e4f56..4be09df0 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -27,10 +27,9 @@ def test_request_object_encryption(): "client_secret": "abcdefghijklmnop", } service_context = ServiceContext(keyjar=KEYJAR, config=conf) - _behav = service_context.work_condition.behaviour - _behav["request_object_encryption_alg"] = "RSA1_5" - _behav["request_object_encryption_enc"] = "A128CBC-HS256" - service_context.work_condition.behaviour = _behav + _condition = service_context.work_condition + _condition.set_usage("request_object_encryption_alg", "RSA1_5") + _condition.set_usage("request_object_encryption_enc", "A128CBC-HS256") _jwe = request_object_encryption(msg.to_json(), service_context, target=RECEIVER) assert _jwe diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index dd0501e5..7358c426 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -66,7 +66,7 @@ "client_id": "xxxxxxx", "client_secret": "yyyyyyyyyyyyyyyyyyyy", "redirect_uris": ["{}/authz_cb/linkedin".format(BASE_URL)], - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["r_basicprofile", "r_emailaddress"], "token_endpoint_auth_method": "client_secret_post", @@ -87,7 +87,7 @@ "issuer": "https://www.facebook.com/v2.11/dialog/oauth", "client_id": "ccccccccc", "client_secret": "dddddddddddddd", - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["email", "public_profile"], "token_endpoint_auth_method": "", @@ -115,7 +115,7 @@ "client_id": "eeeeeeeee", "client_secret": "aaaaaaaaaaaaaaaaaaaa", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["user", "public_repo"], "token_endpoint_auth_method": "", @@ -145,7 +145,7 @@ "client_id": "eeeeeeeee", "client_secret": "aaaaaaaaaaaaaaaaaaaa", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["user", "public_repo"], "token_endpoint_auth_method": "", @@ -256,7 +256,7 @@ def test_init_client(self): "userinfo_endpoint", } - assert _context.get("behaviour") == { + assert _context.get("preference") == { "response_types": ["code"], "scope": ["user", "public_repo"], "token_endpoint_auth_method": "", diff --git a/tests/test_client_29_pushed_auth.py b/tests/test_client_29_pushed_auth.py index 3babbf29..47a0de64 100644 --- a/tests/test_client_29_pushed_auth.py +++ b/tests/test_client_29_pushed_auth.py @@ -32,7 +32,7 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "add_ons": { "pushed_authorization": { "function": "idpyoidc.client.oauth2.add_on.pushed_authorization.add_support", diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index 23f11161..2fdaa79a 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -91,7 +91,7 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.work_condition.behaviour.keys()) == { + assert set(_context.work_condition.use.keys()) == { "token_endpoint_auth_method", "response_types", "scope", diff --git a/tests/test_client_40_dpop.py b/tests/test_client_40_dpop.py index 906aa266..80a9a964 100644 --- a/tests/test_client_40_dpop.py +++ b/tests/test_client_40_dpop.py @@ -29,7 +29,7 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "add_ons": { "dpop": { "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", @@ -77,7 +77,7 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "add_ons": { "dpop": { "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 8e5f5e05..3c0b9804 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -55,7 +55,7 @@ "client_id": "xxxxxxx", "client_secret": "yyyyyyyyyyyyyyyyyyyy", "redirect_uris": ["{}/authz_cb/linkedin".format(BASE_URL)], - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["r_basicprofile", "r_emailaddress"], "token_endpoint_auth_method": "client_secret_post", @@ -76,7 +76,7 @@ "issuer": "https://www.facebook.com/v2.11/dialog/oauth", "client_id": "ccccccccc", "client_secret": "dddddddddddddd", - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["email", "public_profile"], "token_endpoint_auth_method": "", @@ -104,7 +104,7 @@ "client_id": "eeeeeeeee", "client_secret": "aaaaaaaaaaaaaaaaaaaa", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], - "behaviour": { + "preference": { "response_types": ["code"], "scope": ["user", "public_repo"], "token_endpoint_auth_method": "", diff --git a/tests/test_client_50_ciba.py b/tests/test_client_50_ciba.py index bbce977f..283808c5 100644 --- a/tests/test_client_50_ciba.py +++ b/tests/test_client_50_ciba.py @@ -28,7 +28,7 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "behaviour": {"response_types": ["code"]}, + "preference": {"response_types": ["code"]}, "add_ons": { "ciba": { "function": "idpyoidc.client.oidc.add_on.ciba.add_support", diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index 61cb3d5f..454707ed 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -36,7 +36,7 @@ def create_request(self): entity.client_get("service_context").issuer = "https://server.otherop.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").work_condition.behaviour = { + entity.client_get("service_context").work_condition.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", diff --git a/tests/test_client_11_base.py b/tests/xtest_client_11_base.py similarity index 85% rename from tests/test_client_11_base.py rename to tests/xtest_client_11_base.py index f68dd1ec..54792c81 100644 --- a/tests/test_client_11_base.py +++ b/tests/xtest_client_11_base.py @@ -7,7 +7,7 @@ def test_load_registration_response(): "redirect_uris": ["https://example.com/cli/authz_cb"], "client_id": "client_1", "client_secret": "abcdefghijklmnop", - "registration_response": {"issuer": "https://example.com"}, + "issuer": "https://example.com", } client = RP(config=conf) From 7ce47522d4c0523afc625930f31e892b3de71fa4 Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 18 Nov 2022 08:50:03 +0100 Subject: [PATCH 15/76] Fixed tests up to client_28 --- src/idpyoidc/client/entity.py | 10 - src/idpyoidc/client/oauth2/utils.py | 11 +- src/idpyoidc/client/oidc/__init__.py | 4 +- src/idpyoidc/client/oidc/access_token.py | 7 +- src/idpyoidc/client/oidc/authorization.py | 13 +- .../client/oidc/provider_info_discovery.py | 85 +------ src/idpyoidc/client/oidc/registration.py | 15 +- src/idpyoidc/client/oidc/userinfo.py | 2 +- src/idpyoidc/client/oidc/webfinger.py | 6 +- src/idpyoidc/client/service_context.py | 6 +- .../client/work_condition/__init__.py | 7 +- src/idpyoidc/client/work_condition/oidc.py | 23 +- .../client/work_condition/transform.py | 57 ++++- tests/request123456.jwt | 2 +- tests/test_client_02b_entity_metadata.py | 17 +- tests/test_client_04_service.py | 8 +- tests/test_client_21_oidc_service.py | 212 ++++++++++-------- tests/test_client_23_pkce.py | 6 +- tests/test_client_26_read_registration.py | 20 +- tests/test_client_27_conversation.py | 92 +++----- tests/test_client_28_rp_handler_oidc.py | 8 +- 21 files changed, 294 insertions(+), 317 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 8aa657ea..9331504e 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -193,16 +193,6 @@ def config_args(self): } return res - def get_callback_uris(self): - res = [] - for service in self._service.values(): - for _callback in service.callback_uris(): - _uri = self._service_context.work_condition.get_preference(_callback) - if _uri: - res[_callback] = _uri - # res.extend(self._service_context.work_condition.callback_uris) - return res - def prefers(self): return self._service_context.work_condition.prefers() diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 66d176a8..0b314057 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -35,9 +35,10 @@ def pick_redirect_uri( if "redirect_uri" in request_args: return request_args["redirect_uri"] - if context.work_condition.callback: + _callback_uris = context.get_preference("callback_uris") + if _callback_uris: if not response_type: - _conf_resp_types = context.work_condition.get_usage("response_types", []) + _conf_resp_types = context.get_usage("response_types", []) response_type = request_args.get("response_type") if not response_type and _conf_resp_types: response_type = _conf_resp_types[0] @@ -45,11 +46,11 @@ def pick_redirect_uri( _response_mode = request_args.get("response_mode") if _response_mode == "form_post" or response_type == ["form_post"]: - redirect_uri = context.work_condition.callback["form_post"] + redirect_uri = _callback_uris["form_post"] elif response_type == "code" or response_type == ["code"]: - redirect_uri = context.work_condition.callback["code"] + redirect_uri = _callback_uris["code"] else: - redirect_uri = context.work_condition.callback["implicit"] + redirect_uri = _callback_uris["implicit"] logger.debug( f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, " diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 33e0d86f..042a71de 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -79,8 +79,8 @@ def __init__( ) _context = self.get_service_context() - if _context.callback is None: - _context.callback = {} + if _context.get_preference('callback_uris') is None: + _context.set_preference('callback_uris', {}) def fetch_distributed_claims(self, userinfo, callback=None): """ diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 6fc9c783..d35698a5 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -90,8 +90,5 @@ def update_service_context(self, resp, key="", **kwargs): _state_interface.store_item(resp, "token_response", key) def get_authn_method(self): - _work_condition = self.client_get("service_context").work_condition - try: - return _work_condition.get_usage("token_endpoint_auth_method") - except KeyError: - return self.default_authn_method + _context = self.client_get("service_context") + return _context.get_usage("token_endpoint_auth_method", self.default_authn_method) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 8b080e34..34fd8620 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -35,7 +35,7 @@ class Authorization(authorization.Authorization): "response_types_supported": ["code", "form_post"], "request_uris": None, "request_parameter": None, - "encrypt_request_object": None, + "encrypt_request_object_supported": None, "redirect_uris": None, } @@ -182,11 +182,14 @@ def store_request_on_file(self, req, **kwargs): :return: The URL the OP should use to access the file """ _context = self.client_get("service_context") - try: - _webname = _context.registration_response["request_uris"][0] - filename = _context.filename_from_webname(_webname) - except KeyError: + + _webname = _context.get_usage("request_uris") + if _webname is None: filename, _webname = construct_request_uri(**kwargs) + else: + # webname should be a list + _webname = _webname[0] + filename = _context.filename_from_webname(_webname) fid = open(filename, mode="w") fid.write(req) diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index c4d20482..3b52766a 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -73,90 +73,9 @@ def match_preferences(self, pcr=None, issuer=None): :param issuer: The issuer identifier """ _context = self.client_get("service_context") - _entity = self.client_get("entity") - _work_condition = _context.work_condition - - _supports = _context.supports() - _prefers = _context.prefers() - if not pcr: pcr = _context.provider_info - regreq = oidc.RegistrationRequest - prefers = {} - - for _pref, _prov in PREFERENCE2PROVIDER.items(): - _supported_values = _supports.get(_pref) - _preferred_value = _prefers.get(_pref) - - if not _preferred_value: - if not _supported_values: - continue - else: - _supported_values = _preferred_value - - try: - _provider_vals = pcr[_prov] - except KeyError: - try: - # If the provider have not specified use what the - # standard says is mandatory if at all. - _provider_vals = PROVIDER_DEFAULT[_pref] - except KeyError: - logger.info("No info from provider on {} and no default".format(_pref)) - _provider_vals = _supported_values - - if not isinstance(_supported_values, list): - if isinstance(_provider_vals, list): - if _supported_values in _provider_vals: - prefers[_pref] = _supported_values - elif _provider_vals == _supported_values: - prefers[_pref] = _supported_values - else: # _supported_values is a list - try: - vtyp = regreq.c_param[_pref] - except KeyError: - # Allow non standard claims - if isinstance(_supported_values, list) and isinstance(_provider_vals, list): - prefers[_pref] = [v for v in _supported_values if v in _provider_vals] - elif isinstance(_provider_vals, list): - if _supported_values in _provider_vals: - prefers[_pref] = _supported_values - elif type(_supported_values) == type(_provider_vals): - if _supported_values == _provider_vals: - prefers[_pref] = _supported_values - else: - if isinstance(vtyp[0], list): - prefers[_pref] = [] - for val in _supported_values: - if val in _provider_vals: - prefers[_pref].append(_supported_values) - else: - for val in _supported_values: - if val in _provider_vals: - prefers[_pref] = val - break - - if _pref not in prefers: - raise ConfigurationError("OP couldn't match preference:%s" % _pref, pcr) - - for key, val in _supports: - if key in prefers: - continue - if key in ["jwks", "jwks_uri"]: - continue - - try: - vtyp = regreq.c_param[key] - if isinstance(vtyp[0], list): - pass - elif isinstance(val, list) and not isinstance(val, str): - val = val[0] - except KeyError: - pass - if key not in PREFERENCE2PROVIDER: - prefers[key] = val - - # stores it all in one place - _context.work_condition.prefer = prefers + prefers = _context.map_supported_to_preferred(pcr) + logger.debug("Entity prefers: {}".format(prefers)) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index a96974ab..6fdf1b65 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -28,12 +28,13 @@ def __init__(self, client_get, conf=None): self.post_construct = [self.oidc_post_construct] def add_client_preference(self, request_args=None, **kwargs): - _work_condition = self.client_get("service_context") + _context = self.client_get("service_context") + _use = _context.map_preferred_to_register() for prop, spec in self.msg_type.c_param.items(): if prop in request_args: continue - _val = _work_condition.get_preference(prop) + _val = _use.get(prop) if _val: if isinstance(_val, list): if isinstance(spec[0], list): @@ -63,7 +64,6 @@ def update_service_context(self, resp, key="", **kwargs): resp["token_endpoint_auth_method"] = "client_secret_basic" _context = self.client_get("service_context") - _work_condition = _context.work_condition _keyjar = _context.keyjar _context.registration_response = resp @@ -74,18 +74,17 @@ def update_service_context(self, resp, key="", **kwargs): _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) _client_secret = resp.get("client_secret") if _client_secret: - _work_condition.set_usage_claim("client_secret", _client_secret) + _context.set_usage("client_secret", _client_secret) # _context.client_secret = _client_secret _keyjar.add_symmetric("", _client_secret) _keyjar.add_symmetric(_client_id, _client_secret) try: - _work_condition.set_usage_claim("client_secret_expires_at", - resp["client_secret_expires_at"]) + _context.set_usage("client_secret_expires_at", + resp["client_secret_expires_at"]) except KeyError: pass try: - _work_condition.set_usage_claim("registration_access_token", - resp["registration_access_token"]) + _context.set_usage("registration_access_token", resp["registration_access_token"]) except KeyError: pass diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 96bcd42a..436bafd0 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -45,7 +45,7 @@ class UserInfo(Service): "userinfo_signing_alg_values_supported": get_signing_algs, "userinfo_encryption_alg_values_supported": get_encryption_algs, "userinfo_encryption_enc_values_supported": get_encryption_encs, - "encrypt_userinfo": None + "encrypt_userinfo_supported": None } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oidc/webfinger.py b/src/idpyoidc/client/oidc/webfinger.py index b048ccf4..ddfba9ee 100644 --- a/src/idpyoidc/client/oidc/webfinger.py +++ b/src/idpyoidc/client/oidc/webfinger.py @@ -49,10 +49,8 @@ def update_service_context(self, resp, key="", **kwargs): for link in links: if link["rel"] == self.rel: _href = link["href"] - try: - _http_allowed = self.get_conf_attr("allow", default={})["http_links"] - except KeyError: - _http_allowed = False + _context = self.client_get('service_context') + _http_allowed = 'http_links' in _context.get("allow", default={}) if _href.startswith("http://") and not _http_allowed: raise ValueError("http link not allowed ({})".format(_href)) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 6afac781..389c0d33 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -142,13 +142,12 @@ def __init__(self, self.base_url = base_url or config.get("base_url") or config.conf.get('base_url', '') # Below so my IDE won't complain - self.allow = {} + self.allow = config.conf.get('allow') self.args = {} self.add_on = {} self.iss_hash = "" self.issuer = "" self.httpc_params = {} - self.callback = {} self.client_secret_expires_at = 0 self.provider_info = {} # self.post_logout_redirect_uri = "" @@ -335,7 +334,8 @@ def prefer_or_support(self, claim): def map_supported_to_preferred(self, info: Optional[dict] = None): self.work_condition.prefer = supported_to_preferred(self.supports(), self.work_condition.prefer, - info) + base_url=self.base_url, + info=info) return self.work_condition.prefer def map_preferred_to_register(self): diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index 71c4fcf6..4d9b001b 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -63,10 +63,13 @@ def get_preference(self, key, default=None): def set_preference(self, key, value): self.prefer[key] = value + def remove_preference(self, key): + if key in self.prefer: + del self.prefer[key] + def _callback_uris(self, base_url, hex): _uri = [] - for type in self.get_usage("response_types", - self._supports['response_types']): + for type in self.get_usage("response_types", self._supports['response_types']): if "code" in type: _uri.append('code') elif type in ["id_token", "id_token token"]: diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index 21f56295..f5c31ba6 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -27,8 +27,7 @@ class WorkCondition(work_condition.WorkCondition): "jwks": None, "jwks_uri": None, "sector_identifier_uri": None, - "subject_type": None, - "default_max_age": None, + "default_max_age": 86400, "require_auth_time": None, "initiate_login_uri": None, "client_id": None, @@ -36,7 +35,8 @@ class WorkCondition(work_condition.WorkCondition): "scope": ["openid"], # "verify_args": None, "requests_dir": None, - "encrypt_id_token": None + "encrypt_id_token_supported": None, + "callback_uris": None } def __init__(self, @@ -50,6 +50,22 @@ def verify_rules(self): raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'." " you can't have both.") + _cb_uris = self.get_preference('callback_uris') + if _cb_uris: + self.set_preference('redirect_uris', list(_cb_uris.values())) # just overwrite + + if not self.get_preference('encrypt_userinfo_supported'): + self.set_preference('userinfo_encryption_alg_values_supported', []) + self.set_preference('userinfo_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_request_object_supported'): + self.set_preference('request_object_encryption_alg_values_supported', []) + self.set_preference('request_object_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_id_token_supported'): + self.set_preference('id_token_encryption_alg_values_supported', []) + self.set_preference('id_token_encryption_enc_values_supported', []) + def locals(self, info): requests_dir = info.get("requests_dir") if requests_dir: @@ -58,4 +74,3 @@ def locals(self, info): os.makedirs(requests_dir) self.set("requests_dir", requests_dir) - diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index 6d0c8220..ad4d68a6 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -68,17 +68,51 @@ # ] -def supported_to_preferred(supported: dict, preference: dict, info: Optional[dict] = None): - for key, val in supported.items(): - if info and key in info: - preference[key] = info[key] - continue - - if val is None: - continue - - if key not in preference: - preference[key] = val +def supported_to_preferred(supported: dict, + preference: dict, + base_url: str, + info: Optional[dict] = None, + ): + if info: # The provider info + for key, val in supported.items(): + if key in preference: + _pref_val = preference.get(key) # defined in configuration + _info_val = info.get(key) + if _info_val: + # Only use provider setting if less or equal to what I support + if key.endswith('supported'): # list + preference[key] = [x for x in _pref_val if x in _info_val] + else: + pass + elif val is None: # No default + # if key not in ['jwks_uri', 'jwks']: + pass + else: + # there is a default + _info_val = info.get(key) + if _info_val: # The OP has an opinion + if key.endswith('supported'): # list + preference[key] = [x for x in val if x in _info_val] + else: + pass + else: + preference[key] = val + + # special case -> must have a request_uris value + if 'require_request_uri_registration' in info: + # only makes sense if I want to use request_uri + if preference.get('request_parameter') == 'request_uri': + if 'request_uri' not in preference: + preference['request_uris'] = [f'{base_url}/requests'] + else: # just ignore + logger.info('Asked for "request_uri" which it did not plan to use') + else: + # Add defaults + for key, val in supported.items(): + if val is None: + continue + if key not in preference: + preference[key] = val return preference @@ -104,6 +138,7 @@ def preferred_to_register(prefers: dict, use: Optional[dict] = None): else: use[key] = _preferred_values + # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) for key, val in prefers.items(): if PREFERRED2REGISTER.get(key): diff --git a/tests/request123456.jwt b/tests/request123456.jwt index b462cac9..697ab2af 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJtdEk0TWk5WkFweTR4TWNLSkF0c3BXVFRwa1RqcUFTTHpLWHg1Y0VhNEt3IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY4NjIwMDA1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.cSFmCUFZh6jCHiAC7n6EkC_gIkLfnlH2GXCVTUV2SfF19k2wHVH2L8hLj4SmjreoVYKNkhJdt6qpxpmmAP4dWZorUhFJc4j4vp0rIFflffVSg5db1bCvG4_H_XqJxhQdpcUlqfTTkKiqQ9v4fnbh_mTtDJc8ZHLjHaPrRFsSNTvsGeR366PL8bbSrY7F5CX_Ox86B5gIMKDCNt6Cqywd0TcfN5PFrLAKPe3rH1md3dg85dN64xFupSqKhqXlQ3QggrZDQLbGAUnf3YUeqSn2dGb8Of_hVgzfVN33P2uT6x7kkNRmizXEUlsGZ3IiFsPRC59ZF_rObnsRZrGa_9-uLg \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogIjhhaXBhMEszdmZwYVd0SlhXbEZCdzdTcEVHUzZlMFBON0dubEJFMW5qRlkiLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg2ODc1OTQsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.e8VXl40EMAqA_ZpLA_NZ3KJsVMLs8bzqFX81IWQWmUiUOj8pnaKVidKOS9ddGQW0Vt4wuA_iJzrLnYu-MS683RHjocddZhJbEhJmRztBjOZqSgYsP1hQtpOU9U3FDINgn8d6U_eDke1mC46xQhFL8OZQlqgFXaQ5lk2-XHWOsGzQdqiDiwo8aydwwIhXOqNsZFA8StEGnl7iWy-jOq52fRHKBpFy6vDsS7N8tg7-QaoPBOoWl2kTBLWDIWQq3Nu5bIkvk0Qq0ydhCmfJRzeeelhCgtPSQTQY8ulkzTdhsKSBEOv1hAVAP707-jmHjPcVCOMZS_NWGwPO7g94P76n0A \ No newline at end of file diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 583faf0c..2cc9c356 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -15,7 +15,7 @@ "response_types": ["code"], "client_id": "client_id", "redirect_uris": ["https://example.com/cli/authz_cb"], - 'request_parameter': True, + 'request_parameter': "request_uri", "request_object_signing_alg_values_supported": ["ES256"], "scope": ["openid", "profile", "email", "address", "phone"], "token_endpoint_auth_methods_supported": ["private_key_jwt"], @@ -72,6 +72,7 @@ def test_create_client(): 'client_id', 'client_secret', 'contacts', + 'default_max_age', 'grant_types_supported', 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', @@ -98,20 +99,24 @@ def test_create_client(): # assert _context.get_preference("userinfo_signing_alg_values_supported") == ['ES256'] # How to act - _context.work_condition.use = _context.map_preferred_to_register() + _context.map_preferred_to_register() assert _context.get_usage("request_uris") is None - _conf_args = _context.collect_usage() + _conf_args = list(_context.collect_usage().keys()) assert _conf_args - assert len(_conf_args) == 25 + assert len(_conf_args) == 20 rr = set(RegistrationRequest.c_param.keys()) + # The ones that are not defined d = rr.difference(set(_conf_args)) assert d == {'initiate_login_uri', 'client_name', 'post_logout_redirect_uri', 'tos_uri', 'logo_uri', 'jwks_uri', 'federation_type', 'frontchannel_logout_session_required', 'require_auth_time', 'client_uri', 'frontchannel_logout_uri', 'request_uris', - 'sector_identifier_uri', 'default_max_age', 'organization_name', 'policy_uri', - 'default_acr_values'} + 'sector_identifier_uri', 'organization_name', 'policy_uri', + 'default_acr_values', 'userinfo_encrypted_response_alg', + 'id_token_encrypted_response_alg', 'request_object_encryption_alg', + 'userinfo_encrypted_response_enc', 'request_object_encryption_enc', + 'id_token_encrypted_response_enc'} def test_create_client_key_conf(): diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 9101d377..17871060 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -52,12 +52,8 @@ def test_use(self): assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', 'grant_types', 'application_type', 'jwks', 'subject_type', - 'id_token_signed_response_alg', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', - 'request_object_signing_alg', - 'request_object_encryption_alg', - 'request_object_encryption_enc', 'scope'} + 'id_token_signed_response_alg', 'default_max_age', + 'request_object_signing_alg', 'scope'} def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 8ed5a87a..c8973ddd 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -214,12 +214,10 @@ def test_request_param(self): assert os.path.isfile(os.path.join(_dirname, "request123456.jwt")) _context = self.service.client_get("service_context") - _context.registration_response = { - "redirect_uris": ["https://example.com/cb"], - "request_uris": ["https://example.com/request123456.jwt"], - } + _context.set_usage("redirect_uris", ["https://example.com/cb"]) + _context.set_usage("request_uris", ["https://example.com/request123456.jwt"]) _context.base_url = "https://example.com/" - _context.set_usage('request_object_encryption_alg', None) + # _context.set_usage('request_object_encryption_alg', None) _info = self.service.get_request_parameters( request_args=req_args, request_method="reference" ) @@ -311,14 +309,19 @@ def create_request(self): client_config = { "client_id": "client_id", "client_secret": "a longesh password", - "callback": { + "callback_uris": { "code": "https://example.com/cli/authz_cb", "implicit": "https://example.com/cli/authz_im_cb", "form_post": "https://example.com/cli/authz_fp_cb", }, } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) - entity.client_get("service_context").issuer = "https://example.com" + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + client_type='oidc') + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() + self.service = entity.client_get("service", "authorization") def test_construct_code(self): @@ -525,8 +528,19 @@ def create_service(self): "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": self._iss, "application_name": "rphandler", - "support": { + "application_type": "web", + "contacts": ["ops@example.org"], + "preference": { "scope": ["openid", "profile", "email", "address", "phone"], + "response_types_supported": ["code"], + "request_object_signing_alg_values_supported": ["ES256"], + "encrypt_id_token_supported": False, # default + "token_endpoint_auth_methods_supported": ["private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["ES256"], + "userinfo_signing_alg_values_supported": ["ES256"], + "post_logout_redirect_uris": ["https://rp.example.com/post"], + "backchannel_logout_uri": "https://rp.example.com/back", + "backchannel_logout_session_required": True }, "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, @@ -535,48 +549,23 @@ def create_service(self): }, "registration": { "class": "idpyoidc.client.oidc.registration.Registration", - "kwargs": { - "metadata": { - "application_type": "web", - "contacts": ["ops@example.org"], - "response_types": ["code"] - } - } + "kwargs": {} }, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": { - "metadata": { - "request_object_signing_alg": "ES256", - } - } + "kwargs": {} }, "accesstoken": { "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": { - "metadata": { - "token_endpoint_auth_method": "private_key_jwt", - "token_endpoint_auth_signing_alg": "ES256" - } - } + "kwargs": {} }, "userinfo": { "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": { - "metadata": { - "userinfo_signed_response_alg": "ES256" - }, - } + "kwargs": {} }, "end_session": { "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": { - "metadata": { - "post_logout_redirect_uris": ["https://rp.example.com/post"], - "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True - } - } + "kwargs": {} } } } @@ -608,7 +597,7 @@ def test_post_parse(self): "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, - "require_request_uri_registration": True, + # "require_request_uri_registration": True, "grant_types_supported": [ "authorization_code", "implicit", @@ -784,7 +773,8 @@ def test_post_parse(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").work_condition.use == {} + _context = self.service.client_get("service_context") + assert _context.work_condition.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -793,7 +783,19 @@ def test_post_parse(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").work_condition.use == { + # static client registration + _context.map_preferred_to_register() + + use_copy = self.service.client_get("service_context").work_condition.use.copy() + # jwks content will change dynamically between runs + assert 'jwks' in use_copy + del use_copy['jwks'] + + assert use_copy == { + 'client_secret': 'a longesh password', + 'contacts': ['ops@example.org'], + 'default_max_age': 86400, + 'encrypt_id_token_supported': False, 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -806,7 +808,9 @@ def test_post_parse(self): 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', 'userinfo_signed_response_alg': 'ES256', - 'scope': ["openid", "profile", "email", "address", "phone"] + 'scope': ["openid", "profile", "email", "address", "phone"], + 'request_object_signing_alg': 'ES256', + 'subject_type': 'public' } def test_post_parse_2(self): @@ -828,7 +832,8 @@ def test_post_parse_2(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - assert self.service.client_get("service_context").work_condition.use == {} + _context = self.service.client_get("service_context") + assert _context.work_condition.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -837,7 +842,15 @@ def test_post_parse_2(self): self.service.update_service_context(resp) - assert self.service.client_get("service_context").work_condition.use == { + # static client registration + _context.map_preferred_to_register() + + use_copy = self.service.client_get("service_context").work_condition.use.copy() + # jwks content will change dynamically between runs + assert 'jwks' in use_copy + del use_copy['jwks'] + + assert use_copy == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', @@ -850,7 +863,13 @@ def test_post_parse_2(self): 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', 'userinfo_signed_response_alg': 'ES256', - 'scope': ["openid", "profile", "email", "address", "phone"] + 'scope': ["openid", "profile", "email", "address", "phone"], + 'client_secret': 'a longesh password', + 'contacts': ['ops@example.org'], + 'default_max_age': 86400, + 'encrypt_id_token_supported': False, + 'request_object_signing_alg': 'ES256', + 'subject_type': 'public' } @@ -892,15 +911,15 @@ def create_request(self): def test_construct(self): _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 6 + assert len(_req) == 5 def test_config_with_post_logout(self): - self.service.client_get("service_context").work_condition.set_metadata_claim( + self.service.client_get("service_context").work_condition.set_usage( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 7 + assert len(_req) == 6 assert "post_logout_redirect_uri" in _req @@ -911,6 +930,8 @@ def test_config_with_required_request_uri(): "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": ISS, "requests_dir": "requests", + "request_parameter": 'request_uri', + 'request_uris': ["https://example.com/cli/requests"], "base_url": "https://example.com/cli", } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, @@ -925,7 +946,9 @@ def test_config_with_required_request_uri(): assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == {"application_type", "response_types", "jwks", "redirect_uris", "grant_types", "id_token_signed_response_alg", - "request_uris"} + "request_uris", 'default_max_age', 'request_object_signing_alg', + 'subject_type', 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} def test_config_logout_uri(): @@ -935,12 +958,14 @@ def test_config_logout_uri(): "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": ISS, "requests_dir": "requests", + "request_uris": ["https://example.com/cli/requests"], "base_url": "https://example.com/cli/", - "usage": { - "request_parameter_preference": "request_uri" + "preference": { + "request_parameter": "request_uri" } } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=SERVICES) + entity = Entity(keyjar=make_keyjar(), config=client_config, services=SERVICES, + client_type='oidc') _context = entity.client_get("service_context") _context.issuer = "https://example.com" @@ -951,9 +976,19 @@ def test_config_logout_uri(): reg_service = entity.client_get("service", "registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 8 - assert "request_uris" in _req - assert "backchannel_logout_uri" in _req + assert set(_req.keys()) == {'application_type', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'request_uris', + 'response_types', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} class TestUserInfo(object): @@ -969,7 +1004,8 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + client_type='oidc') entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "userinfo") @@ -1172,13 +1208,14 @@ def create_request(self): "issuer": self._iss, "requests_dir": "requests", "base_url": "https://example.com/cli/", - "metadata": { - "post_logout_redirect_uris": ["https://example.com/post_logout"] - } + "post_logout_redirect_uris": ["https://example.com/post_logout"] } services = {"checksession": {"class": "idpyoidc.client.oidc.end_session.EndSession"}} entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) - entity.client_get("service_context").issuer = "https://example.com" + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() self.service = entity.client_get("service", "end_session") def test_construct(self): @@ -1214,8 +1251,12 @@ def test_authz_service_conf(): }, } } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) - entity.client_get("service_context").issuer = "https://example.com" + entity = Entity(keyjar=make_keyjar(), config=client_config, services=services, + client_type='oidc') + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() service = entity.client_get("service", "authorization") req = service.construct() @@ -1227,33 +1268,20 @@ def test_jwks_uri_conf(): client_config = { "client_secret": "a longesh password", "issuer": ISS, - "metadata": { - "client_id": "client_id", - "jwks_uri": "https://example.com/jwks/jwks.json", - "redirect_uris": ["https://example.com/cli/authz_cb"], - "id_token_signed_response_alg": "RS384", - "userinfo_signed_response_alg": "RS384", - }, + "client_id": "client_id", + "jwks_uri": "https://example.com/jwks/jwks.json", + "redirect_uris": ["https://example.com/cli/authz_cb"], + "id_token_signed_response_alg": "RS384", + "userinfo_signed_response_alg": "RS384", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) - assert entity.will_use("jwks_uri") - + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + client_type='oidc') + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() -def test_add_jwks_uri_or_jwks(): - client_config = { - "client_secret": "a longesh password", - "issuer": ISS, - "metadata": { - "client_id": "client_id", - "redirect_uris": ["https://example.com/cli/authz_cb"], - "jwks_uri": "https://example.com/jwks/jwks.json", - "id_token_signed_response_alg": "RS384", - "userinfo_signed_response_alg": "RS384", - }, - } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) - # jwks_uri has higher priority the jwks - assert entity.will_use("jwks_uri") + assert _context.get_usage("jwks_uri") def test_jwks_uri_arg(): @@ -1272,5 +1300,11 @@ def test_jwks_uri_arg(): config=client_config, jwks_uri="https://example.com/jwks/jwks.json", services=DEFAULT_OIDC_SERVICES, + client_type='oidc' ) - assert entity.will_use("jwks_uri") + _context = entity.client_get("service_context") + _context.issuer = "https://example.com" + _context.map_supported_to_preferred() + _context.map_preferred_to_register() + + assert _context.get_usage("jwks_uri") diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index 6d92ce58..80b81c64 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -56,10 +56,14 @@ def create_client(self): } }, } - self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES, + client_type='oauth2') if "add_ons" in config: do_add_ons(config["add_ons"], self.entity.client_get("services")) + _context = self.entity.get_service_context() + _context.map_supported_to_preferred() + _context.map_preferred_to_register() def test_add_code_challenge_default_values(self): auth_serv = self.entity.client_get("service", "authorization") diff --git a/tests/test_client_26_read_registration.py b/tests/test_client_26_read_registration.py index 21959ddb..ba4fed7d 100644 --- a/tests/test_client_26_read_registration.py +++ b/tests/test_client_26_read_registration.py @@ -18,19 +18,16 @@ class TestRegistrationRead(object): def create_request(self): self._iss = ISS client_config = { - "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": self._iss, "requests_dir": "requests", "base_url": "https://example.com/cli/", - "metadata": { - "application_type": "web", - "response_types": ["code"], - "contacts": ["ops@example.org"], - "jwks_uri": "https://example.com/rp/static/jwks.json", - "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], - "token_endpoint_auth_method": "client_secret_basic", - "grant_types": ["authorization_code"], - }, + "application_type": "web", + "response_types": ["code"], + "contacts": ["ops@example.org"], + "jwks_uri": "https://example.com/rp/static/jwks.json", + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code"], } services = { "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, @@ -40,6 +37,9 @@ def create_request(self): } self.entity = Entity(config=client_config, services=services) + _context = self.entity.get_service_context() + _context.map_supported_to_preferred() + _context.map_preferred_to_register() self.reg_service = self.entity.client_get("service", "registration") self.read_service = self.entity.client_get("service", "registration_read") diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 6b2a1852..da698e21 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -114,66 +114,46 @@ }, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": { - "metadata": { - "request_object_signing_alg": "ES256" - }, - "usage": { - "request_uri": True - } - } + "kwargs": {} }, "accesstoken": { "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": { - "metadata": { - "token_endpoint_auth_method": "private_key_jwt", - "token_endpoint_auth_signing_alg": "ES256" - } - } + "kwargs": {} }, "refresh_token": { "class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken" }, "userinfo": { "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": { - "metadata": { - "userinfo_signed_response_alg": "ES256" - }, - } + "kwargs": {} }, "end_session": { "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": { - "metadata": { - "post_logout_redirect_uri": "https://rp.example.com/post", - "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True - }, - "usage": { - "backchannel_logout": True - } - } + "kwargs": {} } } def test_conversation(): config = { - "metadata": { - "application_type": "web", - "contacts": ["ops@example.org"], - "redirect_uris": [f"{RP_BASEURL}/authz_cb"], - "response_types": ["code"], - }, - "usage": { - "scope": ["openid", "profile", "email", "address", "phone"], - }, + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "response_types": ["code"], + "scope": ["openid", "profile", "email", "address", "phone"], + "request_object_signing_alg": "ES256", + "request_uris": [f"{RP_BASEURL}/requests"], + "token_endpoint_auth_methods_supported": ["private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["ES256"], + "userinfo_signing_alg_values_supported": ["ES256"], + "post_logout_redirect_uri": "https://rp.example.com/post", + "backchannel_logout_uri": "https://rp.example.com/back", + "backchannel_logout_session_required": True, + 'allow': {'missing_kid': True}, "services": SERVICES } - entity = Entity(config=config, keyjar=RP_KEYJAR) + entity = Entity(config=config, keyjar=RP_KEYJAR, client_type='oidc') assert set(entity.client_get("services").keys()) == { "accesstoken", @@ -437,20 +417,22 @@ def test_conversation(): assert info["url"] == "https://example.org/op/registration" _body = json.loads(info["body"]) - assert set(_body.keys()) == { - "application_type", - 'backchannel_logout_uri', - 'backchannel_logout_session_required', - "contacts", - "grant_types", - 'id_token_signed_response_alg', - 'jwks', - "redirect_uris", - "response_types", - "token_endpoint_auth_method", - 'userinfo_signed_response_alg', - 'token_endpoint_auth_signing_alg', - } + assert set(_body.keys()) == {'application_type', + 'backchannel_logout_session_required', + 'backchannel_logout_uri', + 'contacts', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'request_uris', + 'response_types', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} assert info["headers"] == {"Content-Type": "application/json"} now = int(time.time()) @@ -477,7 +459,7 @@ def test_conversation(): registration_service.update_service_context(response) assert service_context.get_client_id() == "zls2qhN1jO6A" - assert service_context.client_secret == "c8434f28cf9375d9a7" + assert service_context.get_usage('client_secret') == "c8434f28cf9375d9a7" assert set(service_context.registration_response.keys()) == { "client_secret_expires_at", "contacts", @@ -535,7 +517,7 @@ def test_conversation(): # =================== Access token ==================== token_service = entity.client_get("service", "accesstoken") - request_args = {"state": STATE, "redirect_uri": entity.get_metadata_value("redirect_uris")[0]} + request_args = {"state": STATE, "redirect_uri": service_context.get_usage("redirect_uris")[0]} info = token_service.get_request_parameters(request_args=request_args) diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 7358c426..9d48394e 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -20,7 +20,7 @@ BASE_URL = "https://example.com/rp" -METADATA = { +PREF = { "application_type": "web", "contacts": ["ops@example.com"], "response_types": [ @@ -32,9 +32,6 @@ "code token", ], "token_endpoint_auth_method": "client_secret_basic", -} - -USAGE = { "scope": ["openid", "profile", "email", "address", "phone"], "verify_args": {"allow_sign_alg_none": True}, "request_uri": True @@ -42,8 +39,7 @@ CLIENT_CONFIG = { "": { - "metadata": METADATA, - "usage": USAGE, + "preference": PREF, "redirect_uris": None, "base_url": BASE_URL, "requests_dir": "requests", From 65f98dcbf0f4046a4d3d08da73c2f8f311275e84 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 18 Nov 2022 14:54:01 +0100 Subject: [PATCH 16/76] working on tests. --- src/idpyoidc/client/entity.py | 17 +++++--- src/idpyoidc/client/oauth2/authorization.py | 15 +++---- src/idpyoidc/client/oidc/authorization.py | 7 ++- src/idpyoidc/client/rp_handler.py | 5 +++ src/idpyoidc/client/service.py | 3 +- src/idpyoidc/client/service_context.py | 25 ++++++----- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/test_client_28_rp_handler_oidc.py | 48 ++++++++++----------- 9 files changed, 65 insertions(+), 59 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 9331504e..2f0f926f 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -1,5 +1,4 @@ import logging -import os from typing import Optional from typing import Union @@ -119,9 +118,15 @@ def __init__( self._service_context.work_condition.load_conf(config.conf, supports=self._service_context.supports()) - self._service_context.construct_uris(self._service_context.issuer, - self._service_context.hash_seed, - config.conf.get("callback")) + _response_types = self._service_context.get_preference( + 'response_types_supported', + self._service_context.supports()['response_types_supported']) + + _callback_uris = self._service_context.construct_uris(self._service_context.issuer, + self._service_context.hash_seed, + config.conf.get("callback"), + response_types=_response_types) + self._service_context.set_usage('callback_uris', _callback_uris) def client_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) @@ -167,7 +172,7 @@ def backward_compatibility(self, config): _work_condition = self._service_context.work_condition _uris = config.get("redirect_uris") if _uris: - _work_condition.set_preference("redirect_uris", _uris) + _work_condition.set_preference("redirect_uris", _uris) _dir = config.conf.get("requests_dir") if _dir: @@ -197,4 +202,4 @@ def prefers(self): return self._service_context.work_condition.prefers() def use(self): - return self._service_context.work_condition.get_use() \ No newline at end of file + return self._service_context.work_condition.get_use() diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index ec3788a3..79d777df 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -95,7 +95,8 @@ def post_parse_response(self, response, **kwargs): def construct_uris(self, base_url: str, hex: bytes, targets: Optional[List[str]] = None, - preference: Optional[dict] = None): + preference: Optional[dict] = None, + response_types: Optional[List[str]] = None): if not targets: targets = list(self._callback_path.keys()) @@ -104,13 +105,10 @@ def construct_uris(self, base_url: str, hex: bytes, spec = self._callback_path.get(uri_name) if spec: if uri_name == "redirect_uris": # another layer - _uris = [] + _uris = {} for typ, path in spec.items(): add = False - if 'response_type' in preference: - if typ in preference['response_type']: - add = True - elif typ in preference: + if typ in response_types: add = True elif 'response_type' in self._supports: if typ in self._supports['response_type']: @@ -119,8 +117,9 @@ def construct_uris(self, base_url: str, hex: bytes, add = True if add: - _uris.append(self.get_uri(base_url, path, hex)) + _uris[typ] = self.get_uri(base_url, path, hex) res[uri_name] = _uris elif uri_name in preference or uri_name in self._supports: res[uri_name] = self.get_uri(base_url, spec, hex) - return res \ No newline at end of file + + return res \ No newline at end of file diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 34fd8620..da1e7153 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -325,10 +325,9 @@ def gather_verify_arguments( except KeyError: pass - try: - kwargs["allow_missing_kid"] = _context.allow["missing_kid"] - except KeyError: - pass + _allow = _context.allow.get("missing_kid") + if _allow: + kwargs["allow_missing_kid"] = _allow _verify_args = _context.get_usage("verify_args") if _verify_args: diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 8849a710..1df984e9 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -264,6 +264,9 @@ def do_provider_info( _kj.add_kb(_context.get("issuer"), _kb) else: raise ValueError("Unknown provider JWKS type: {}".format(typ)) + + _context.map_supported_to_preferred(info=_pi) + try: return _context.get("provider_info")["issuer"] except KeyError: @@ -312,6 +315,8 @@ def do_client_registration( request_args.update({k: v for k, v in behaviour_args.items() if k in _params}) load_registration_response(client, request_args=request_args) + else: + _context.map_preferred_to_register() def do_webfinger(self, user: str) -> Client: """ diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index ffbe2f6d..3a77c811 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -632,7 +632,8 @@ def get_uri(base_url, path, hex): def construct_uris(self, base_url: str, hex: bytes, targets: Optional[List[str]] = None, - preference: Optional[dict] = None): + preference: Optional[dict] = None, + response_types: Optional[list] = None): if not targets: targets = self._callback_path.keys() res = {} diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 389c0d33..3c83ff97 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -9,8 +9,8 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes @@ -22,9 +22,9 @@ from idpyoidc.util import rndstr from .configure import get_configuration from .state_interface import StateInterface +from .work_condition import WorkCondition from .work_condition import work_condition_dump from .work_condition import work_condition_load -from .work_condition import WorkCondition from .work_condition.transform import preferred_to_register from .work_condition.transform import supported_to_preferred @@ -93,8 +93,6 @@ class ServiceContext(OidcContext): "args": None, "base_url": None, "behaviour": None, - "callback": None, - "client_secret": None, "client_secret_expires_at": 0, "clock_skew": None, "config": None, @@ -142,7 +140,7 @@ def __init__(self, self.base_url = base_url or config.get("base_url") or config.conf.get('base_url', '') # Below so my IDE won't complain - self.allow = config.conf.get('allow') + self.allow = config.conf.get('allow', {}) self.args = {} self.add_on = {} self.iss_hash = "" @@ -286,7 +284,7 @@ def prefers(self): return self.work_condition.prefer def get_preference(self, claim, default=None): - return self.work_condition.get_preference(claim) + return self.work_condition.get_preference(claim, default=default) def set_preference(self, key, value): self.work_condition.set_preference(key, value) @@ -300,7 +298,8 @@ def set_usage(self, claim, value): def construct_uris(self, issuer: str, hash_seed: bytes, - callback: Optional[dict] = None): + callback: Optional[dict] = None, + response_types: Optional[list] = None): _hash = hashlib.sha256() _hash.update(hash_seed) _hash.update(as_bytes(issuer)) @@ -309,14 +308,14 @@ def construct_uris(self, self.iss_hash = _hex _base_url = self.get("base_url") + _uris = {} if self.client_get: services = self.client_get('services') for service in services.values(): - service.construct_uris(base_url=_base_url, hex=_hex, - preference=self.work_condition.prefer) - - # if not self.work_condition.get_usage("redirect_uris"): - # self.work_condition.construct_redirect_uris(_base_url, _hex, callback) + _uris.update(service.construct_uris(base_url=_base_url, hex=_hex, + preference=self.work_condition.prefer, + response_types=response_types)) + return _uris def prefer_or_support(self, claim): if claim in self.work_condition.prefer: @@ -341,4 +340,4 @@ def map_supported_to_preferred(self, info: Optional[dict] = None): def map_preferred_to_register(self): self.work_condition.use = preferred_to_register(self.work_condition.prefer, self.work_condition.use) - return self.work_condition.use \ No newline at end of file + return self.work_condition.use diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 9d48394e..6de2e186 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -4,16 +4,16 @@ from urllib.parse import urlparse from urllib.parse import urlsplit +from cryptojwt.key_jar import init_key_jar import pytest import responses -from cryptojwt.key_jar import init_key_jar from idpyoidc.client.entity import Entity from idpyoidc.client.rp_handler import RPHandler -from idpyoidc.message.oidc import JRD from idpyoidc.message.oidc import AccessTokenResponse from idpyoidc.message.oidc import AuthorizationResponse from idpyoidc.message.oidc import IdToken +from idpyoidc.message.oidc import JRD from idpyoidc.message.oidc import Link from idpyoidc.message.oidc import OpenIDSchema from idpyoidc.message.oidc import ProviderConfigurationResponse @@ -112,9 +112,9 @@ "client_secret": "aaaaaaaaaaaaaaaaaaaa", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { - "response_types": ["code"], + "response_types_supported": ["code"], "scope": ["user", "public_repo"], - "token_endpoint_auth_method": "", + "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, "provider_info": { @@ -127,10 +127,7 @@ "class": "idpyoidc.client.oidc.authorization.Authorization", }, "access_token": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, - "userinfo": { - "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": {"conf": {"default_authn_method": ""}}, - }, + "userinfo": {"class": "idpyoidc.client.oidc.userinfo.UserInfo"}, "refresh_access_token": { "class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken" }, @@ -241,9 +238,11 @@ def test_init_client(self): _context = client.client_get("service_context") - assert _context.get_client_id() == "eeeeeeeee" - assert _context.get("client_secret") == "aaaaaaaaaaaaaaaaaaaa" - assert _context.get("issuer") == "https://github.com/login/oauth/authorize" + # Neither provider info discovery not client registration has been done + # So only preferences so far. + assert _context.get_preference('client_id') == "eeeeeeeee" + assert _context.get_preference("client_secret") == "aaaaaaaaaaaaaaaaaaaa" + assert _context.issuer == "https://github.com/login/oauth/authorize" assert _context.get("provider_info") is not None assert set(_context.get("provider_info").keys()) == { @@ -252,12 +251,8 @@ def test_init_client(self): "userinfo_endpoint", } - assert _context.get("preference") == { - "response_types": ["code"], - "scope": ["user", "public_repo"], - "token_endpoint_auth_method": "", - "verify_args": {"allow_sign_alg_none": True}, - } + _pref = [k for k,v in _context.prefers().items() if v] + assert _pref == ['jwks', 'client_id', 'client_secret', 'redirect_uris', 'scope'] _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -293,7 +288,8 @@ def test_do_client_registration(self): assert self.rph.hash2issuer["github"] == issuer assert ( - client.client_get("service_context").work_condition.callback.get("post_logout_redirect_uris") is None + client.client_get("service_context").work_condition.callback.get( + "post_logout_redirect_uris") is None ) def test_do_client_setup(self): @@ -301,9 +297,11 @@ def test_do_client_setup(self): _github_id = iss_id("github") _context = client.client_get("service_context") - assert _context.get_client_id() == "eeeeeeeee" - assert _context.get("client_secret") == "aaaaaaaaaaaaaaaaaaaa" - assert _context.get("issuer") == _github_id + # Neither provider info discovery not client registration has been done + # So only preferences so far. + assert _context.get_preference('client_id') == "eeeeeeeee" + assert _context.get_preference("client_secret") == "aaaaaaaaaaaaaaaaaaaa" + assert _context.issuer == _github_id _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -322,14 +320,14 @@ def test_create_callbacks(self): client = self.rph.init_client("https://op.example.com/") _srv = client.client_get("service", "registration") _context = _srv.client_get("service_context") - # add_callbacks(_context, []) - cb = _srv.client_get("service_context").work_condition.callback + cb = _srv.client_get("service_context").get_usage('callback_uris') - assert set(cb.keys()) == {"code", "implicit"} + assert set(cb.keys()) == {"request_uris", "redirect_uris"} + assert set(cb['redirect_uris'].keys()) == {'code', 'form_post'} _hash = _context.iss_hash - assert cb["code"] == f"https://example.com/rp/authz_cb/{_hash}" + assert cb['redirect_uris']["code"] == f"https://example.com/rp/authz_cb/{_hash}" assert list(self.rph.hash2issuer.keys()) == [_hash] From 44e37430597a43952789f20554aa60e0b4774e52 Mon Sep 17 00:00:00 2001 From: roland Date: Sun, 20 Nov 2022 09:42:38 +0100 Subject: [PATCH 17/76] Fixed tests up to client_28. Back and forth. --- src/idpyoidc/client/entity.py | 18 +++++++++++------ src/idpyoidc/client/oauth2/authorization.py | 20 +++++++++++++------ src/idpyoidc/client/oauth2/utils.py | 19 +++++++++++++----- src/idpyoidc/client/oidc/authorization.py | 6 ++++-- src/idpyoidc/client/service.py | 9 ++++++--- src/idpyoidc/client/service_context.py | 16 ++++++--------- .../client/work_condition/__init__.py | 1 - src/idpyoidc/client/work_condition/oidc.py | 4 ++++ tests/request123456.jwt | 2 +- tests/test_client_02b_entity_metadata.py | 3 ++- tests/test_client_04_service.py | 7 ++++--- tests/test_client_21_oidc_service.py | 13 +++++++++--- 12 files changed, 77 insertions(+), 41 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 2f0f926f..6b8cb091 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -63,6 +63,13 @@ def set_jwks_uri_or_jwks(service_context, config, jwks_uri, keyjar): _set_jwks(service_context, config, keyjar) +def redirect_uris_from_callback_uris(callback_uris): + res = [] + for k, v in callback_uris['redirect_uris'].items(): + res.extend(v) + return res + + class Entity(object): def __init__( @@ -120,13 +127,12 @@ def __init__( _response_types = self._service_context.get_preference( 'response_types_supported', - self._service_context.supports()['response_types_supported']) + self._service_context.supports().get('response_types_supported', [])) - _callback_uris = self._service_context.construct_uris(self._service_context.issuer, - self._service_context.hash_seed, - config.conf.get("callback"), - response_types=_response_types) - self._service_context.set_usage('callback_uris', _callback_uris) + _callback_uris = self._service_context.construct_uris(response_types=_response_types) + if _callback_uris: + self._service_context.set_preference('redirect_uris', + redirect_uris_from_callback_uris(_callback_uris)) def client_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 79d777df..eb2ac17d 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -7,6 +7,7 @@ from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oauth2.utils import set_state_parameter from idpyoidc.client.service import Service +from idpyoidc.client.service_context import ServiceContext from idpyoidc.exception import MissingParameter from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -59,7 +60,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: - # ar_args["redirect_uri"] = self.client_get("service_context").redirect_uris[0] + # _cb = self.client_get("service_context").get_usage("callback_uris") ar_args["redirect_uri"] = self.client_get("service_context").get_usage( "redirect_uris")[0] except (KeyError, AttributeError): @@ -93,15 +94,19 @@ def post_parse_response(self, response, **kwargs): pass return response - def construct_uris(self, base_url: str, hex: bytes, + def construct_uris(self, + base_url: str, + hex: bytes, + context: ServiceContext, targets: Optional[List[str]] = None, - preference: Optional[dict] = None, response_types: Optional[List[str]] = None): if not targets: targets = list(self._callback_path.keys()) - res = {} + res = context.get_preference('callback_uris', {}) for uri_name in targets: + if uri_name in res: + continue spec = self._callback_path.get(uri_name) if spec: if uri_name == "redirect_uris": # another layer @@ -119,7 +124,10 @@ def construct_uris(self, base_url: str, hex: bytes, if add: _uris[typ] = self.get_uri(base_url, path, hex) res[uri_name] = _uris - elif uri_name in preference or uri_name in self._supports: + elif uri_name == 'request_uris': + if 'request_uri' == context.get_preference('request_parameter'): + res[uri_name] = self.get_uri(base_url, spec, hex) + elif uri_name in context.prefers() or uri_name in self._supports: res[uri_name] = self.get_uri(base_url, spec, hex) - return res \ No newline at end of file + return res diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 0b314057..3214d261 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -36,6 +36,9 @@ def pick_redirect_uri( return request_args["redirect_uri"] _callback_uris = context.get_preference("callback_uris") + if _callback_uris: + _callback_uris = _callback_uris.get("redirect_uris") + if _callback_uris: if not response_type: _conf_resp_types = context.get_usage("response_types", []) @@ -45,12 +48,18 @@ def pick_redirect_uri( _response_mode = request_args.get("response_mode") - if _response_mode == "form_post" or response_type == ["form_post"]: - redirect_uri = _callback_uris["form_post"] - elif response_type == "code" or response_type == ["code"]: - redirect_uri = _callback_uris["code"] + if _response_mode: + if _response_mode == "form_post": + redirect_uri = _callback_uris["form_post"][0] + elif response_type == "code" or response_type == ["code"]: + redirect_uri = _callback_uris["code"][0] + else: + redirect_uri = _callback_uris["implicit"][0] else: - redirect_uri = _callback_uris["implicit"] + if 'code' == response_type: + redirect_uri = _callback_uris["code"][0] + else: + redirect_uri = _callback_uris["implicit"][0] logger.debug( f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, " diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index da1e7153..3c685006 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -32,18 +32,20 @@ class Authorization(authorization.Authorization): "request_object_signing_alg_values_supported": work_condition.get_signing_algs, "request_object_encryption_alg_values_supported": work_condition.get_encryption_algs, "request_object_encryption_enc_values_supported": work_condition.get_encryption_encs, - "response_types_supported": ["code", "form_post"], + "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', + 'code id_token', 'code idtoken token'], "request_uris": None, "request_parameter": None, "encrypt_request_object_supported": None, "redirect_uris": None, + "response_modes_supported": ['query', 'fragment', 'form_post'] } _callback_path = { "request_uris": "req", "redirect_uris": { # based on response_types "code": "authz_cb", - "implicit": "authz_im_cb", + "token": "authz_tok_cb", "form_post": "form" } } diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 3a77c811..1ac4173e 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -26,6 +26,8 @@ __author__ = "Roland Hedberg" +from ..context import OidcContext + LOGGER = logging.getLogger(__name__) SUCCESSFUL = [200, 201, 202, 203, 204, 205, 206] @@ -630,9 +632,11 @@ def get_callback_path(self, callback): def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" - def construct_uris(self, base_url: str, hex: bytes, + def construct_uris(self, + base_url: str, + hex: bytes, + context: OidcContext, targets: Optional[List[str]] = None, - preference: Optional[dict] = None, response_types: Optional[list] = None): if not targets: targets = self._callback_path.keys() @@ -650,7 +654,6 @@ def callback_uris(self): return list(self._callback_path.keys()) - def init_services(service_definitions, client_get): """ Initiates a set of services diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 3c83ff97..84ece19e 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -9,8 +9,8 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import RSAKey from cryptojwt.jwk.rsa import import_private_rsa_key_from_file +from cryptojwt.jwk.rsa import RSAKey from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes @@ -22,9 +22,9 @@ from idpyoidc.util import rndstr from .configure import get_configuration from .state_interface import StateInterface -from .work_condition import WorkCondition from .work_condition import work_condition_dump from .work_condition import work_condition_load +from .work_condition import WorkCondition from .work_condition.transform import preferred_to_register from .work_condition.transform import supported_to_preferred @@ -295,14 +295,10 @@ def get_usage(self, claim, default: Optional[str] = None): def set_usage(self, claim, value): return self.work_condition.set_usage(claim, value) - def construct_uris(self, - issuer: str, - hash_seed: bytes, - callback: Optional[dict] = None, - response_types: Optional[list] = None): + def construct_uris(self, response_types: Optional[list] = None): _hash = hashlib.sha256() - _hash.update(hash_seed) - _hash.update(as_bytes(issuer)) + _hash.update(self.hash_seed) + _hash.update(as_bytes(self.issuer)) _hex = _hash.hexdigest() self.iss_hash = _hex @@ -313,7 +309,7 @@ def construct_uris(self, services = self.client_get('services') for service in services.values(): _uris.update(service.construct_uris(base_url=_base_url, hex=_hex, - preference=self.work_condition.prefer, + context=self, response_types=response_types)) return _uris diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index 4d9b001b..298c952c 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -46,7 +46,6 @@ def __init__(self, self.callback_path = callback_path or {} self.use = {} self._local = {} - self.callback = {} def get_use(self): return self.use diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index f5c31ba6..062d8a3c 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -53,6 +53,10 @@ def verify_rules(self): _cb_uris = self.get_preference('callback_uris') if _cb_uris: self.set_preference('redirect_uris', list(_cb_uris.values())) # just overwrite + else: + _uris = self.get_preference('redirect_uris') + if _uris: + self.set_preference('callback_uris', {'redirect_uris': {'code': _uris}}) if not self.get_preference('encrypt_userinfo_supported'): self.set_preference('userinfo_encryption_alg_values_supported', []) diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 697ab2af..68ca9f2f 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogIjhhaXBhMEszdmZwYVd0SlhXbEZCdzdTcEVHUzZlMFBON0dubEJFMW5qRlkiLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg2ODc1OTQsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.e8VXl40EMAqA_ZpLA_NZ3KJsVMLs8bzqFX81IWQWmUiUOj8pnaKVidKOS9ddGQW0Vt4wuA_iJzrLnYu-MS683RHjocddZhJbEhJmRztBjOZqSgYsP1hQtpOU9U3FDINgn8d6U_eDke1mC46xQhFL8OZQlqgFXaQ5lk2-XHWOsGzQdqiDiwo8aydwwIhXOqNsZFA8StEGnl7iWy-jOq52fRHKBpFy6vDsS7N8tg7-QaoPBOoWl2kTBLWDIWQq3Nu5bIkvk0Qq0ydhCmfJRzeeelhCgtPSQTQY8ulkzTdhsKSBEOv1hAVAP707-jmHjPcVCOMZS_NWGwPO7g94P76n0A \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogInQtZElTaFg0NWc2TnJtcWtrOWE3RGM3bjJiWGg0WVJhaUwzUWFQOTg2WlEiLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg3OTgwOTUsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.gb_JlFwEXKKlJpPjkE5wyl6hyDFRO6VXES4MmHLsmeyeqvR4tdkReZlFZEss8Zu9re9r6rJAUKri41H8S19FkbvYI4OgY1FdCQz2fjTIvzQ0E0Hd5RfAFx_IXyyC8-wa_KQRGGP3a16m1pP2N19JgsfXudhFygD6RuExROXLRg-z8jeg1mhaRp0EaWg61KNMMK0F6i17M790jeKeIvvevpyBCJA4qvgAu5W9d7_LRm_2sjCh_TdvOaYXkknIiYzGSDuqU1DrYSuMQBB3-n3G6kzXipwfJ1j3Pg7XvhZBPdFIV64CLmwO1eLhB50vJexGduU6128t42JlCEQCY2xXQQ \ No newline at end of file diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 2cc9c356..364b7427 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -69,6 +69,7 @@ def test_create_client(): assert set(_pref.keys()) == {'application_type', 'backchannel_logout_session_required', 'backchannel_logout_uri', + 'callback_uris', 'client_id', 'client_secret', 'contacts', @@ -105,7 +106,7 @@ def test_create_client(): _conf_args = list(_context.collect_usage().keys()) assert _conf_args - assert len(_conf_args) == 20 + assert len(_conf_args) == 21 rr = set(RegistrationRequest.c_param.keys()) # The ones that are not defined d = rr.difference(set(_conf_args)) diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 17871060..2e3a98dc 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -20,9 +20,10 @@ def __init__(self, status_code, text, headers=None): CLIENT_CONF = { "redirect_uris": ["https://example.com/cli/authz_cb"], - "preference": {"response_types": ["code"]}, + "preference": {"response_types_supported": ["code"]}, "key_conf": {"key_defs": KEYDEFS}, - "client_id": 'CLIENT' + "client_id": 'CLIENT', + 'base_url': "https://example.com/cli" } @@ -53,7 +54,7 @@ def test_use(self): assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', 'grant_types', 'application_type', 'jwks', 'subject_type', 'id_token_signed_response_alg', 'default_max_age', - 'request_object_signing_alg', 'scope'} + 'request_object_signing_alg', 'scope', 'callback_uris'} def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index c8973ddd..b8d2f29e 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -77,7 +77,14 @@ def create_request(self): client_config = { "client_id": "client_id", "client_secret": "a longesh password", - "redirect_uris": ["https://example.com/cli/authz_cb"], + "callbak_uris": { + "redirect_uris": { # different flows + "code": ["https://example.com/cli/authz_cb"], + "implicit": ["https://example.com/cli/imp_cb"], + "form_post": ["https://example.com/cli/form"] + } + }, + "response_types_supported": ['code', 'token'] } entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=make_keyjar(), config=client_config, client_type='oidc') @@ -311,7 +318,7 @@ def create_request(self): "client_secret": "a longesh password", "callback_uris": { "code": "https://example.com/cli/authz_cb", - "implicit": "https://example.com/cli/authz_im_cb", + "token": "https://example.com/cli/authz_im_cb", "form_post": "https://example.com/cli/authz_fp_cb", }, } @@ -600,7 +607,7 @@ def test_post_parse(self): # "require_request_uri_registration": True, "grant_types_supported": [ "authorization_code", - "implicit", + "token", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", ], From aaa964afc77bb1a149f8ca45c3ae32954e2cbd39 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 22 Nov 2022 10:30:27 +0100 Subject: [PATCH 18/76] working on tests... --- src/idpyoidc/client/entity.py | 5 +- src/idpyoidc/client/oauth2/authorization.py | 82 +++-- src/idpyoidc/client/oidc/authorization.py | 46 +++ src/idpyoidc/client/oidc/end_session.py | 8 +- src/idpyoidc/client/oidc/registration.py | 2 +- src/idpyoidc/client/rp_handler.py | 9 +- src/idpyoidc/client/service.py | 26 +- src/idpyoidc/client/service_context.py | 35 +- src/idpyoidc/client/util.py | 18 +- src/idpyoidc/client/work_condition/oauth2.py | 2 +- src/idpyoidc/client/work_condition/oidc.py | 2 +- .../client/work_condition/transform.py | 90 +++-- src/idpyoidc/message/oidc/__init__.py | 4 +- tests/pub_client.jwks | 2 +- tests/request123456.jwt | 2 +- tests/test_08_transform.py | 332 ++++++++++++++++++ tests/test_client_02b_entity_metadata.py | 3 +- tests/test_client_04_service.py | 3 +- tests/test_client_21_oidc_service.py | 78 ++-- tests/test_client_23_pkce.py | 13 +- tests/test_client_28_rp_handler_oidc.py | 44 +-- 21 files changed, 623 insertions(+), 183 deletions(-) create mode 100644 tests/test_08_transform.py diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 6b8cb091..977308b8 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -129,10 +129,7 @@ def __init__( 'response_types_supported', self._service_context.supports().get('response_types_supported', [])) - _callback_uris = self._service_context.construct_uris(response_types=_response_types) - if _callback_uris: - self._service_context.set_preference('redirect_uris', - redirect_uris_from_callback_uris(_callback_uris)) + self._service_context.construct_uris(response_types=_response_types) def client_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index eb2ac17d..2fd49fdc 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -8,6 +8,7 @@ from idpyoidc.client.oauth2.utils import set_state_parameter from idpyoidc.client.service import Service from idpyoidc.client.service_context import ServiceContext +from idpyoidc.client.util import implicit_response_types from idpyoidc.exception import MissingParameter from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -28,7 +29,8 @@ class Authorization(Service): response_body_type = "urlencoded" _supports = { - "response_types": ["code"] + "response_types_supported": ["code", 'token'], + "response_modes_supported": ['query', 'fragment'] } _callback_path = { @@ -94,40 +96,56 @@ def post_parse_response(self, response, **kwargs): pass return response + def _do_flow(self, flow_type, response_types): + if flow_type == 'code' and 'code' in response_types: + return True + elif flow_type == 'implicit': + if implicit_response_types(response_types): + return True + return False + + def _do_redirect_uris(self, base_url, hex, context, callback_uris, response_types): + _redirect_uris = context.get_preference('redirect_uris', []) + if _redirect_uris: + if not callback_uris or 'redirect_uris' not in callback_uris: + # the same redirect_uris for all flow types + callback_uris['redirect_uris'] = {} + for flow_type in self._callback_path['redirect_uris'].keys(): + if self._do_flow(flow_type, response_types): + callback_uris['redirect_uris'][flow_type] = _redirect_uris + elif callback_uris: + if 'redirect_uris' in callback_uris: + pass + else: + callback_uris['redirect_uris'] = {} + for flow_type, path in self._callback_path['redirect_uris'].items(): + if self._do_flow(flow_type, response_types): + callback_uris['redirect_uris'][flow_type] = self.get_uri(base_url, path, hex) + else: + callback_uris['redirect_uris'] = {} + for flow_type, path in self._callback_path['redirect_uris'].items(): + if self._do_flow(flow_type, response_types): + callback_uris['redirect_uris'][flow_type] = self.get_uri(base_url, path, hex) + return callback_uris + def construct_uris(self, base_url: str, hex: bytes, context: ServiceContext, targets: Optional[List[str]] = None, response_types: Optional[List[str]] = None): - if not targets: - targets = list(self._callback_path.keys()) - - res = context.get_preference('callback_uris', {}) - for uri_name in targets: - if uri_name in res: - continue - spec = self._callback_path.get(uri_name) - if spec: - if uri_name == "redirect_uris": # another layer - _uris = {} - for typ, path in spec.items(): - add = False - if typ in response_types: - add = True - elif 'response_type' in self._supports: - if typ in self._supports['response_type']: - add = True - elif typ in self._supports: - add = True - - if add: - _uris[typ] = self.get_uri(base_url, path, hex) - res[uri_name] = _uris - elif uri_name == 'request_uris': - if 'request_uri' == context.get_preference('request_parameter'): - res[uri_name] = self.get_uri(base_url, spec, hex) - elif uri_name in context.prefers() or uri_name in self._supports: - res[uri_name] = self.get_uri(base_url, spec, hex) - - return res + _callback_uris = context.get_preference('callback_uris', {}) + + for uri_name in self._callback_path.keys(): + if uri_name == 'redirect_uris': + _callback_uris = self._do_redirect_uris(base_url, hex, context, _callback_uris, + response_types) + _redirect_uris = set() + for flow, _uris in _callback_uris['redirect_uris'].items(): + _redirect_uris.update(set(_uris)) + context.set_preference('redirect_uris', list(_redirect_uris)) + else: + _callback_uris[uri_name] = self.get_uri(base_url, self._callback_path[uri_name], + hex) + + return _callback_uris diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 3c685006..4b243746 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -1,4 +1,5 @@ import logging +from typing import List from typing import Optional from typing import Union @@ -8,6 +9,8 @@ from idpyoidc.client.oidc import IDT2REG from idpyoidc.client.oidc.utils import construct_request_uri from idpyoidc.client.oidc.utils import request_object_encryption +from idpyoidc.client.service_context import ServiceContext +from idpyoidc.client.util import implicit_response_types from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.message import Message from idpyoidc.message import oauth2 @@ -58,6 +61,7 @@ def __init__(self, client_get, conf=None): self.oidc_pre_construct, ] self.post_construct = [self.oidc_post_construct] + self.default_request_args = {'scope': ['openid']} def set_state(self, request_args, **kwargs): try: @@ -336,3 +340,45 @@ def gather_verify_arguments( kwargs.update(_verify_args) return kwargs + + def _do_request_uris(self, base_url, hex, context, callback_uris): + _uri_name = 'request_uris' + if context.get_preference('request_parameter') == _uri_name: + if _uri_name not in callback_uris: + callback_uris[_uri_name] = self.get_uri(base_url, + self._callback_path[_uri_name], + hex) + return callback_uris + + def _do_type(self, context, typ, response_types): + if typ == 'code' and 'code' in response_types: + if typ in context.get_preference('response_modes_supported'): + return True + elif typ == 'implicit': + if typ in context.get_preference('response_modes_supported'): + if implicit_response_types(response_types): + return True + elif typ == 'form_post': + if typ in context.get_preference('response_modes_supported'): + return True + return False + + def construct_uris(self, + base_url: str, + hex: bytes, + context: ServiceContext, + targets: Optional[List[str]] = None, + response_types: Optional[List[str]] = None): + _callback_uris = context.get_preference('callback_uris', {}) + + for uri_name in self._callback_path.keys(): + if uri_name == 'redirect_uris': + _callback_uris = self._do_redirect_uris(base_url, hex, context, _callback_uris, + response_types) + elif uri_name == 'request_uris': + _callback_uris = self._do_request_uris(base_url, hex, context, _callback_uris) + else: + _callback_uris[uri_name] = self.get_uri(base_url, self._callback_path[uri_name], + hex) + + return _callback_uris diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 8d8901f7..1f2e6577 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -21,17 +21,19 @@ class EndSession(Service): response_body_type = "html" _supports = { - "post_logout_redirect_uris": None, + "post_logout_redirect_uri": None, + 'frontchannel_logout_supported': None, "frontchannel_logout_uri": None, "frontchannel_logout_session_required": None, + 'backchannel_logout_supported': None, "backchannel_logout_uri": None, "backchannel_logout_session_required": None } - callback_path = { + _callback_path = { "frontchannel_logout_uri": "fc_logout", "backchannel_logout_uri": "bc_logout", - "post_logout_redirect_uris": "session_logout" + "post_logout_redirect_uri": "session_logout" } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 6fdf1b65..fd708b12 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -29,7 +29,7 @@ def __init__(self, client_get, conf=None): def add_client_preference(self, request_args=None, **kwargs): _context = self.client_get("service_context") - _use = _context.map_preferred_to_register() + _use = _context.map_preferred_to_registered() for prop, spec in self.msg_type.c_param.items(): if prop in request_args: continue diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 1df984e9..2302a6d9 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -316,7 +316,7 @@ def do_client_registration( load_registration_response(client, request_args=request_args) else: - _context.map_preferred_to_register() + _context.map_preferred_to_registered() def do_webfinger(self, user: str) -> Client: """ @@ -514,11 +514,8 @@ def get_client_authn_method(client, endpoint): :return: The client authentication method """ if endpoint == "token_endpoint": - try: - am = client.client_get("service_context").work_condition.get_usage( - "token_endpoint_auth_method" - ) - except KeyError: + am = client.client_get("service_context").get_usage("token_endpoint_auth_method") + if not am: return "" else: if isinstance(am, str): diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 1ac4173e..7e45f497 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -13,8 +13,8 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import is_error_message from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oauth2 import is_error_message from idpyoidc.util import importer from .configure import Configuration from .exception import ResponseError @@ -121,7 +121,7 @@ def gather_request_args(self, **kwargs): _context = self.client_get("service_context") _use = _context.collect_usage() if not _use: - _use = _context.map_preferred_to_register() + _use = _context.map_preferred_to_registered() if "request_args" in self.conf: ar_args.update(self.conf["request_args"]) @@ -136,9 +136,11 @@ def gather_request_args(self, **kwargs): if prop in ar_args: continue - val = self.default_request_args.get(prop) + val = _use.get(prop) if not val: - val = _use.get(prop) + #val = request_claim(_context, prop) + #if not val: + val = self.default_request_args.get(prop) if val: ar_args[prop] = val @@ -640,12 +642,18 @@ def construct_uris(self, response_types: Optional[list] = None): if not targets: targets = self._callback_path.keys() - res = {} + + if not targets: + return {} + + _callback_uris = context.get_preference('callback_uris', {}) for uri in targets: - _path = self._callback_path.get(uri) - if _path: - res[uri] = self.get_uri(base_url, _path, hex) - return res + if uri in _callback_uris: + pass + else: + _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(uri), hex) + + return _callback_uris def supported(self, claim): return claim in self._supports diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 84ece19e..7edcac64 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -9,8 +9,8 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes @@ -22,10 +22,10 @@ from idpyoidc.util import rndstr from .configure import get_configuration from .state_interface import StateInterface +from .work_condition import WorkCondition from .work_condition import work_condition_dump from .work_condition import work_condition_load -from .work_condition import WorkCondition -from .work_condition.transform import preferred_to_register +from .work_condition.transform import preferred_to_registered from .work_condition.transform import supported_to_preferred logger = logging.getLogger(__name__) @@ -295,6 +295,14 @@ def get_usage(self, claim, default: Optional[str] = None): def set_usage(self, claim, value): return self.work_condition.set_usage(claim, value) + def _callback_per_service(self): + _cb = {} + for service in self.client_get('services').values(): + _cbs = service._callback_path.keys() + if _cbs: + _cb[service.service_name] = _cbs + return _cb + def construct_uris(self, response_types: Optional[list] = None): _hash = hashlib.sha256() _hash.update(self.hash_seed) @@ -304,14 +312,21 @@ def construct_uris(self, response_types: Optional[list] = None): self.iss_hash = _hex _base_url = self.get("base_url") - _uris = {} + + _callback_uris = self.get_preference('callback_uris', {}) if self.client_get: services = self.client_get('services') for service in services.values(): - _uris.update(service.construct_uris(base_url=_base_url, hex=_hex, - context=self, - response_types=response_types)) - return _uris + _callback_uris.update(service.construct_uris(base_url=_base_url, hex=_hex, + context=self, + response_types=response_types)) + + self.set_preference('callback_uris', _callback_uris) + if 'redirect_uris' in _callback_uris: + _redirect_uris = set() + for flow, _uris in _callback_uris['redirect_uris'].items(): + _redirect_uris.update(set(_uris)) + self.set_preference('redirect_uris', list(_redirect_uris)) def prefer_or_support(self, claim): if claim in self.work_condition.prefer: @@ -333,7 +348,7 @@ def map_supported_to_preferred(self, info: Optional[dict] = None): info=info) return self.work_condition.prefer - def map_preferred_to_register(self): - self.work_condition.use = preferred_to_register(self.work_condition.prefer, + def map_preferred_to_registered(self): + self.work_condition.use = preferred_to_registered(self.work_condition.prefer, self.work_condition.use) return self.work_condition.use diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index 4fcfde75..13d16965 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -1,8 +1,8 @@ """Utilities""" -import logging -import secrets from http.cookiejar import Cookie from http.cookiejar import http2time +import logging +import secrets from urllib.parse import parse_qs from urllib.parse import urlsplit from urllib.parse import urlunsplit @@ -307,3 +307,17 @@ def lower_or_upper(config, param, default=None): if not res: res = config.get(param.upper(), default) return res + + +IMPLICIT_RESPONSE_TYPES = [ + {'id_token'}, {'id_token', 'token'}, {'code', 'token'}, ['code', 'id_token'], + {'code', 'id_token', 'token'}, {'token'} +] + + +def implicit_response_types(a): + res = [] + for typ in a: + if set(typ.split(' ')) in IMPLICIT_RESPONSE_TYPES: + res.append(typ) + return res diff --git a/src/idpyoidc/client/work_condition/oauth2.py b/src/idpyoidc/client/work_condition/oauth2.py index c4d25440..8b0861b9 100644 --- a/src/idpyoidc/client/work_condition/oauth2.py +++ b/src/idpyoidc/client/work_condition/oauth2.py @@ -14,7 +14,7 @@ class WorkCondition(work_condition.WorkCondition): "client_uri": None, "logo_uri": None, "contacts": None, - "scope": None, + "scopes_supported": [], "tos_uri": None, "policy_uri": None, "jwks_uri": None, diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index 062d8a3c..2e609fb0 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -32,7 +32,7 @@ class WorkCondition(work_condition.WorkCondition): "initiate_login_uri": None, "client_id": None, "client_secret": None, - "scope": ["openid"], + "scopes_supported": ["openid"], # "verify_args": None, "requests_dir": None, "encrypt_id_token_supported": None, diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index ad4d68a6..c35a5252 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -23,12 +23,12 @@ "response_types": "response_types_supported", "grant_types": "grant_types_supported", "scope": "scopes_supported", - "display": "display_values_supported", - "claims": "claims_supported", - "request": "request_parameter_supported", - "request_uri": "request_uri_parameter_supported", - 'claims_locales': 'claims_locales_supported', - 'ui_locales': 'ui_locales_supported', + # "display": "display_values_supported", + # "claims": "claims_supported", + # "request": "request_parameter_supported", + # "request_uri": "request_uri_parameter_supported", + # 'claims_locales': 'claims_locales_supported', + # 'ui_locales': 'ui_locales_supported', } PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) @@ -41,33 +41,11 @@ 'redirect_uri': "redirect_uris", 'response_type': "response_types", 'request_uri': "request_uris", - 'grant_type': "grant_types" + 'grant_type': "grant_types", + "scope": 'scopes_supported', } -# AUTHORIZATION_REQUEST = [ -# "acr_values", -# "claims", -# "claims_locales", -# "client_id", -# "display", -# "id_token_hint", -# "login_hint", -# "max_age", -# "nonce", -# "prompt", -# "redirect_uri", -# "registration", -# "request", -# "request_uri", -# "response_mode" -# "response_type", -# "scope", -# "state", -# "ui_locales", -# ] - - def supported_to_preferred(supported: dict, preference: dict, base_url: str, @@ -84,7 +62,7 @@ def supported_to_preferred(supported: dict, preference[key] = [x for x in _pref_val if x in _info_val] else: pass - elif val is None: # No default + elif val is None: # No default, means the RP does not have a preference # if key not in ['jwks_uri', 'jwks']: pass else: @@ -117,26 +95,40 @@ def supported_to_preferred(supported: dict, return preference -def preferred_to_register(prefers: dict, use: Optional[dict] = None): - if not use: - use = {} +def array_to_singleton(claim_spec, values): + if isinstance(claim_spec[0], list): + return values + else: + if isinstance(values, list): + return values[0] + else: # singleton + return values + + +def preferred_to_registered(prefers: dict, registration_response: Optional[dict] = None): + """ + The claims with values that are returned from the OP is what goes unless (!!) + the values returned are not within the supported values. + + @param prefers: + @param registration_response: + @return: + """ + registered = {} + + if registration_response: + for key, val in registration_response.items(): + registered[key] = val # Should I just accept with the OP says ?? for key, spec in RegistrationResponse.c_param.items(): + if key in registered: + continue _pref_key = REGISTER2PREFERRED.get(key, key) _preferred_values = prefers.get(_pref_key) if not _preferred_values: continue - - if isinstance(spec[0], list): - if _preferred_values: - use[key] = _preferred_values - else: - if _preferred_values: - if isinstance(_preferred_values, list): - use[key] = _preferred_values[0] - else: - use[key] = _preferred_values + registered[key] = array_to_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) @@ -144,7 +136,11 @@ def preferred_to_register(prefers: dict, use: Optional[dict] = None): if PREFERRED2REGISTER.get(key): continue if key not in _rr_keys: - use[key] = val + registered[key] = val + + logger.debug(f"Entity registered: {registered}") + return registered + - logger.debug(f"Entity uses: {use}") - return use +def register_to_request(prefers, registration_response): + pass \ No newline at end of file diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 639c7f93..cf53198e 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -635,8 +635,8 @@ class RegistrationRequest(Message): "frontchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN, "backchannel_logout_uri": SINGLE_OPTIONAL_STRING, "backchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN, - "federation_type": OPTIONAL_LIST_OF_STRINGS, - "organization_name": SINGLE_OPTIONAL_STRING, + # "federation_type": OPTIONAL_LIST_OF_STRINGS, + # "organization_name": SINGLE_OPTIONAL_STRING, } c_default = {"application_type": "web", "response_types": ["code"]} c_allowed_values = { diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 68ca9f2f..ab46a129 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogInQtZElTaFg0NWc2TnJtcWtrOWE3RGM3bjJiWGg0WVJhaUwzUWFQOTg2WlEiLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg3OTgwOTUsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.gb_JlFwEXKKlJpPjkE5wyl6hyDFRO6VXES4MmHLsmeyeqvR4tdkReZlFZEss8Zu9re9r6rJAUKri41H8S19FkbvYI4OgY1FdCQz2fjTIvzQ0E0Hd5RfAFx_IXyyC8-wa_KQRGGP3a16m1pP2N19JgsfXudhFygD6RuExROXLRg-z8jeg1mhaRp0EaWg61KNMMK0F6i17M790jeKeIvvevpyBCJA4qvgAu5W9d7_LRm_2sjCh_TdvOaYXkknIiYzGSDuqU1DrYSuMQBB3-n3G6kzXipwfJ1j3Pg7XvhZBPdFIV64CLmwO1eLhB50vJexGduU6128t42JlCEQCY2xXQQ \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogIkZ3eG5aTk9Hc3hHUVRpblhRTmpwaUtqclA2WVQ1YUJBeGFmNV9yV1l2Y3ciLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg5MzQzNDYsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.pzfhYGlDcii7YEtbz9PxER3_ILD-3FfqyoVWOIAMuZ8Ryd0Nx4j7qbG43zjLIm0bTTm1VGXbeH7DC8mnVQfEEX-b_QYNjcFbH-1O9xTKgIq_B3RGfHu0TpAbTkAtJZO92pFKFDmrUKGDg0B00J0XKH4rG_gVY8u-X4x63veVus9peKtzOgyN0WiUixwIYgwpfQx8-CR4MozJuJ9Q28QIDft53Vwl2puwvTy2n29IHIO6qIvO-Aho02FY9guC9QVhMHcSlUxilg_iqCNQQ6GE1yB2pDP2sTWpQI97BZ4XNeQ_JC86gQnjVsaWhGAhCFuJDLfIsTPVdLhWGCEIIZkkLQ \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py new file mode 100644 index 00000000..742557ba --- /dev/null +++ b/tests/test_08_transform.py @@ -0,0 +1,332 @@ +from typing import Callable + +from cryptojwt.utils import importer +import pytest + +from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC +from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED +from idpyoidc.client.work_condition.transform import array_to_singleton +from idpyoidc.client.work_condition.transform import preferred_to_registered +from idpyoidc.client.work_condition.transform import supported_to_preferred +from idpyoidc.message.oidc import ProviderConfigurationResponse +from idpyoidc.message.oidc import RegistrationRequest + + +class TestTransform: + @pytest.fixture(autouse=True) + def setup(self): + supported = WorkConditionOIDC._supports.copy() + for service in [ + 'idpyoidc.client.oidc.access_token.AccessToken', + 'idpyoidc.client.oidc.authorization.Authorization', + 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', + 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', + 'idpyoidc.client.oidc.check_id.CheckID', + 'idpyoidc.client.oidc.check_session.CheckSession', + 'idpyoidc.client.oidc.end_session.EndSession', + 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', + 'idpyoidc.client.oidc.read_registration.RegistrationRead', + 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', + 'idpyoidc.client.oidc.registration.Registration', + 'idpyoidc.client.oidc.userinfo.UserInfo', + 'idpyoidc.client.oidc.webfinger.WebFinger' + ]: + cls = importer(service) + supported.update(cls._supports) + + for key, val in supported.items(): + if isinstance(val, Callable): + supported[key] = val() + self.supported = supported + + def test_supported(self): + # These are all the available configuration parameters + assert set(self.supported.keys()) == { + 'acr_values_supported', + 'application_type', + 'backchannel_logout_session_required', + 'backchannel_logout_supported', + 'backchannel_logout_uri', + 'callback_uris', + 'client_id', + 'client_name', + 'client_secret', + 'client_uri', + 'contacts', + 'default_max_age', + 'encrypt_id_token_supported', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'frontchannel_logout_session_required', + 'frontchannel_logout_supported', + 'frontchannel_logout_uri', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'initiate_login_uri', + 'jwks', + 'jwks_uri', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'redirect_uris', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'request_parameter', + 'request_uris', + 'requests_dir', + 'require_auth_time', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'sector_identifier_uri', + 'subject_types_supported', + 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'tos_uri', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} + + def test_oidc_setup(self): + # This is OP specified stuff + assert set(ProviderConfigurationResponse.c_param.keys()).difference( + set(self.supported)) == { + 'authorization_endpoint', + 'check_session_iframe', + 'claim_types_supported', + 'claims_locales_supported', + 'claims_parameter_supported', + 'claims_supported', + 'display_values_supported', + 'end_session_endpoint', + 'error', + 'error_description', + 'error_uri', + 'issuer', + 'op_policy_uri', + 'op_tos_uri', + 'registration_endpoint', + 'request_parameter_supported', + 'request_uri_parameter_supported', + 'require_request_uri_registration', + 'service_documentation', + 'token_endpoint', + 'ui_locales_supported', + 'userinfo_endpoint'} + + # parameters that are not mapped against what the OP's provider info says + assert set(self.supported).difference( + set(ProviderConfigurationResponse.c_param.keys())) == { + 'application_type', + 'backchannel_logout_uri', + 'callback_uris', + 'client_id', + 'client_name', + 'client_secret', + 'client_uri', + 'contacts', + 'default_max_age', + 'encrypt_id_token_supported', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'frontchannel_logout_uri', + 'initiate_login_uri', + 'jwks', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'redirect_uris', + 'request_parameter', + 'request_uris', + 'requests_dir', + 'require_auth_time', + 'sector_identifier_uri', + 'tos_uri'} + + preference = {} + pref = supported_to_preferred(supported=self.supported, preference=preference, + base_url='https://example.com') + + # These are the claims that has default values. A default value may be an empty list. + # This is the case for claims like id_token_encryption_enc_values_supported. + assert set(pref.keys()) == {'application_type', + 'default_max_age', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'subject_types_supported', + 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} + + # To verify that I have all the necessary claims to do client registration + reg_claim = [] + for key, spec in RegistrationRequest.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + if _pref_key in self.supported: + reg_claim.append(key) + + assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == set() + + # Which ones are list -> singletons + + l_to_s = [] + non_oidc = [] + for key, pref_key in REGISTER2PREFERRED.items(): + spec = RegistrationRequest.c_param.get(key) + if spec is None: + non_oidc.append(pref_key) + elif isinstance(spec[0], list): + l_to_s.append(key) + + assert set(non_oidc) == {'scopes_supported'} + assert set(l_to_s) == {'response_types', 'grant_types', 'default_acr_values'} + + def test_provider_info(self): + OP_BASEURL = 'https://example.com' + provider_info_response = { + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "issuer": OP_BASEURL, + "jwks_uri": f"{OP_BASEURL}/static/jwks_tE2iLbOAqXhe8bqh.json", + "authorization_endpoint": f"{OP_BASEURL}/authorization", + "token_endpoint": f"{OP_BASEURL}/token", + "userinfo_endpoint": f"{OP_BASEURL}/userinfo", + "registration_endpoint": f"{OP_BASEURL}/registration", + "end_session_endpoint": f"{OP_BASEURL}/end_session", + # below are a set which the RP has default values but the OP overwrites + "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], + "response_types_supported": ['code', 'id_token', 'code id_token'], + "response_modes_supported": ['query', 'form_post', 'new_fangled'], + # this does not have a default value + "acr_values_supported": ['mfa'], + } + + preference = {} + pref = supported_to_preferred(supported=self.supported, preference=preference, + base_url='https://example.com', + info=provider_info_response) + + # These are the claims that has default values + assert set(pref.keys()) == {'application_type', + 'default_max_age', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'subject_types_supported', + 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} + + # least common denominator + # The RP supports less than the OP + assert pref['scopes_supported'] == ['openid'] + assert pref["response_modes_supported"] == ['query', 'form_post'] + # The OP supports less than the RP + assert pref["response_types_supported"] == ['code', 'id_token', 'code id_token'] + + def test_registration_response(self): + OP_BASEURL = 'https://example.com' + provider_info_response = { + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "issuer": OP_BASEURL, + "jwks_uri": f"{OP_BASEURL}/static/jwks_tE2iLbOAqXhe8bqh.json", + "authorization_endpoint": f"{OP_BASEURL}/authorization", + "token_endpoint": f"{OP_BASEURL}/token", + "userinfo_endpoint": f"{OP_BASEURL}/userinfo", + "registration_endpoint": f"{OP_BASEURL}/registration", + "end_session_endpoint": f"{OP_BASEURL}/end_session", + # below are a set which the RP has default values but the OP overwrites + "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], + "response_types_supported": ['code', 'id_token', 'code id_token'], + "response_modes_supported": ['query', 'form_post', 'new_fangled'], + # this does not have a default value + "acr_values_supported": ['mfa'], + } + + preference = { + "application_type": "web", + "redirect_uris": + ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "client_name#ja-Jpan-JP": + "クライアント名", + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + pref = supported_to_preferred(supported=self.supported, preference=preference, + base_url='https://example.com', + info=provider_info_response) + + registration_request = {} + for key, spec in RegistrationRequest.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + if _pref_key in pref: + value = pref[_pref_key] + elif _pref_key in self.supported: + value = self.supported[_pref_key] + else: + value = None + + if value: + #registration_request[key] = array_to_singleton(spec, value) + registration_request[key] = value + + registration_response = { + "application_type": "web", + "redirect_uris": + ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "client_name#ja-Jpan-JP": + "クライアント名", + "logo_uri": "https://client.example.org/logo.png", + "subject_type": "pairwise", + "sector_identifier_uri": + "https://other.example.net/file_of_redirect_uris.json", + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "contacts": ["ve7jtb@example.org", "mary@example.org"], + "request_uris": [ + "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + } + + to_use = preferred_to_registered(prefers=pref, + registration_response=registration_response) + + assert to_use diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 364b7427..e4b5d437 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -85,8 +85,9 @@ def test_create_client(): 'request_object_encryption_enc_values_supported', 'request_object_signing_alg_values_supported', 'request_parameter', + 'response_modes_supported', 'response_types_supported', - 'scope', + 'scopes_supported', 'subject_types_supported', 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 2e3a98dc..6d5191f9 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -54,7 +54,8 @@ def test_use(self): assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', 'grant_types', 'application_type', 'jwks', 'subject_type', 'id_token_signed_response_alg', 'default_max_age', - 'request_object_signing_alg', 'scope', 'callback_uris'} + 'request_object_signing_alg', 'callback_uris', + 'response_modes_supported'} def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index b8d2f29e..f5dab9af 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1,14 +1,14 @@ import json import os -import pytest -import responses from cryptojwt.exception import UnsupportedAlgorithm from cryptojwt.jws import jws from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar +import pytest +import responses from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.entity import Entity @@ -77,7 +77,7 @@ def create_request(self): client_config = { "client_id": "client_id", "client_secret": "a longesh password", - "callbak_uris": { + "callback_uris": { "redirect_uris": { # different flows "code": ["https://example.com/cli/authz_cb"], "implicit": ["https://example.com/cli/imp_cb"], @@ -317,9 +317,11 @@ def create_request(self): "client_id": "client_id", "client_secret": "a longesh password", "callback_uris": { - "code": "https://example.com/cli/authz_cb", - "token": "https://example.com/cli/authz_im_cb", - "form_post": "https://example.com/cli/authz_fp_cb", + "redirect_uris": { + "code": ["https://example.com/cli/authz_cb"], + "implicit": ["https://example.com/cli/authz_im_cb"], + "form_post": ["https://example.com/cli/authz_fp_cb"] + }, }, } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, @@ -798,27 +800,29 @@ def test_post_parse(self): assert 'jwks' in use_copy del use_copy['jwks'] - assert use_copy == { - 'client_secret': 'a longesh password', - 'contacts': ['ops@example.org'], - 'default_max_age': 86400, - 'encrypt_id_token_supported': False, - 'application_type': 'web', - 'backchannel_logout_session_required': True, - 'backchannel_logout_uri': 'https://rp.example.com/back', - 'client_id': 'client_id', - 'grant_types': ['authorization_code', 'implicit', 'refresh_token'], - 'id_token_signed_response_alg': 'RS256', - 'post_logout_redirect_uris': ['https://rp.example.com/post'], - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'response_types': ['code'], - 'token_endpoint_auth_method': 'private_key_jwt', - 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256', - 'scope': ["openid", "profile", "email", "address", "phone"], - 'request_object_signing_alg': 'ES256', - 'subject_type': 'public' - } + assert use_copy == {'application_type': 'web', + 'backchannel_logout_session_required': True, + 'backchannel_logout_uri': 'https://rp.example.com/back', + 'callback_uris': { + 'redirect_uris': {'code': ['https://example.com/cli/authz_cb']}}, + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'contacts': ['ops@example.org'], + 'default_max_age': 86400, + 'encrypt_id_token_supported': False, + 'grant_types': ['authorization_code', 'refresh_token'], + 'id_token_signed_response_alg': 'RS256', + 'post_logout_redirect_uris': ['https://rp.example.com/post'], + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'request_object_signing_alg': 'ES256', + 'response_modes_supported': ['query', 'fragment', 'form_post'], + 'response_types': ['code'], + 'scope': ['openid', 'profile', 'email', 'address', 'phone'], + 'subject_type': 'public', + 'token_endpoint_auth_method': 'private_key_jwt', + 'token_endpoint_auth_signing_alg': 'ES256', + 'userinfo_signed_response_alg': 'ES256' + } def test_post_parse_2(self): OP_BASEURL = ISS @@ -861,23 +865,25 @@ def test_post_parse_2(self): 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', + 'callback_uris': { + 'redirect_uris': {'code': ['https://example.com/cli/authz_cb']}}, 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'contacts': ['ops@example.org'], + 'default_max_age': 86400, + 'encrypt_id_token_supported': False, 'grant_types': ['authorization_code', 'implicit', 'refresh_token'], 'id_token_signed_response_alg': 'RS256', 'post_logout_redirect_uris': ['https://rp.example.com/post'], 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'request_object_signing_alg': 'ES256', + 'response_modes_supported': ['query', 'fragment', 'form_post'], 'response_types': ['code'], + 'scope': ['openid', 'profile', 'email', 'address', 'phone'], + 'subject_type': 'public', 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256', - 'scope': ["openid", "profile", "email", "address", "phone"], - 'client_secret': 'a longesh password', - 'contacts': ['ops@example.org'], - 'default_max_age': 86400, - 'encrypt_id_token_supported': False, - 'request_object_signing_alg': 'ES256', - 'subject_type': 'public' - } + 'userinfo_signed_response_alg': 'ES256'} def test_response_types_to_grant_types(): diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index 80b81c64..a77d9bda 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -48,15 +48,22 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "preference": {"response_types": ["code"]}, + "preference": { + "response_types": ["code"] + }, "add_ons": { "pkce": { "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", - "kwargs": {"code_challenge_length": 64, "code_challenge_method": "S256"}, + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + }, } }, } - self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES, + self.entity = Entity(keyjar=CLI_KEY, + config=config, + services=DEFAULT_OAUTH2_SERVICES, client_type='oauth2') if "add_ons" in config: diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 6de2e186..5950f015 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -23,7 +23,7 @@ PREF = { "application_type": "web", "contacts": ["ops@example.com"], - "response_types": [ + "response_types_supported": [ "code", "id_token", "id_token token", @@ -31,10 +31,9 @@ "code id_token token", "code token", ], - "token_endpoint_auth_method": "client_secret_basic", - "scope": ["openid", "profile", "email", "address", "phone"], + "token_endpoint_auth_methods_supported": ["client_secret_basic"], + "scopes_supported": ["openid", "profile", "email", "address", "phone"], "verify_args": {"allow_sign_alg_none": True}, - "request_uri": True } CLIENT_CONFIG = { @@ -42,7 +41,7 @@ "preference": PREF, "redirect_uris": None, "base_url": BASE_URL, - "requests_dir": "requests", + "request_parameter": "request_uris", "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, "discovery": { @@ -63,9 +62,9 @@ "client_secret": "yyyyyyyyyyyyyyyyyyyy", "redirect_uris": ["{}/authz_cb/linkedin".format(BASE_URL)], "preference": { - "response_types": ["code"], - "scope": ["r_basicprofile", "r_emailaddress"], - "token_endpoint_auth_method": "client_secret_post", + "response_types_supported": ["code"], + "scopes_supported": ["r_basicprofile", "r_emailaddress"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], }, "provider_info": { "authorization_endpoint": "https://www.linkedin.com/oauth/v2/authorization", @@ -84,9 +83,9 @@ "client_id": "ccccccccc", "client_secret": "dddddddddddddd", "preference": { - "response_types": ["code"], - "scope": ["email", "public_profile"], - "token_endpoint_auth_method": "", + "response_types_supported": ["code"], + "scopes_supported": ["email", "public_profile"], + "token_endpoint_auth_methods_supported": [], }, "redirect_uris": ["{}/authz_cb/facebook".format(BASE_URL)], "provider_info": { @@ -113,7 +112,7 @@ "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { "response_types_supported": ["code"], - "scope": ["user", "public_repo"], + "scopes_supported": ["user", "public_repo", 'openid'], "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, @@ -139,9 +138,9 @@ "client_secret": "aaaaaaaaaaaaaaaaaaaa", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { - "response_types": ["code"], - "scope": ["user", "public_repo"], - "token_endpoint_auth_method": "", + "response_types_supported": ["code"], + "scopes_supported": ["user", "public_repo"], + "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, "provider_info": { @@ -251,8 +250,9 @@ def test_init_client(self): "userinfo_endpoint", } - _pref = [k for k,v in _context.prefers().items() if v] - assert _pref == ['jwks', 'client_id', 'client_secret', 'redirect_uris', 'scope'] + _pref = [k for k, v in _context.prefers().items() if v] + assert set(_pref) == {'jwks', 'client_id', 'client_secret', 'redirect_uris', + 'response_types_supported', 'callback_uris'} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -288,7 +288,7 @@ def test_do_client_registration(self): assert self.rph.hash2issuer["github"] == issuer assert ( - client.client_get("service_context").work_condition.callback.get( + client.client_get("service_context").get_preference('callback_uris').get( "post_logout_redirect_uris") is None ) @@ -321,10 +321,10 @@ def test_create_callbacks(self): _srv = client.client_get("service", "registration") _context = _srv.client_get("service_context") - cb = _srv.client_get("service_context").get_usage('callback_uris') + cb = _srv.client_get("service_context").get_preference('callback_uris') assert set(cb.keys()) == {"request_uris", "redirect_uris"} - assert set(cb['redirect_uris'].keys()) == {'code', 'form_post'} + assert set(cb['redirect_uris'].keys()) == {'code'} _hash = _context.iss_hash assert cb['redirect_uris']["code"] == f"https://example.com/rp/authz_cb/{_hash}" @@ -361,7 +361,7 @@ def test_begin(self): assert query["client_id"] == ["eeeeeeeee"] assert query["redirect_uri"] == ["https://example.com/rp/authz_cb/github"] assert query["response_type"] == ["code"] - assert query["scope"] == ["user public_repo openid"] + assert query["scope"] == ["user public_repo"] def test_get_session_information(self): res = self.rph.begin(issuer_id="github") @@ -377,7 +377,7 @@ def test_get_client_from_session_key(self): # redo self.rph.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").work_condition.metadata["redirect_uris"] = [] + cli2.client_get("service_context").set_preference("redirect_uris", []) self.rph.do_client_registration(state=res["state"]) def test_finalize_auth(self): From b1f5dcd87fc7bfe1fac14ae809ad2aeb686baf07 Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 22 Nov 2022 17:24:56 +0100 Subject: [PATCH 19/76] Fixed tests up to client_28. Back and forth... --- .../client/work_condition/transform.py | 25 ++- tests/test_08_transform.py | 143 ++++++++++++++---- 2 files changed, 133 insertions(+), 35 deletions(-) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index c35a5252..83f5e47e 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse logger = logging.getLogger(__name__) @@ -95,9 +96,12 @@ def supported_to_preferred(supported: dict, return preference -def array_to_singleton(claim_spec, values): +def array_or_singleton(claim_spec, values): if isinstance(claim_spec[0], list): - return values + if isinstance(values, list): + return values + else: + return [values] else: if isinstance(values, list): return values[0] @@ -128,7 +132,7 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] _preferred_values = prefers.get(_pref_key) if not _preferred_values: continue - registered[key] = array_to_singleton(spec, _preferred_values) + registered[key] = array_or_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) @@ -142,5 +146,16 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] return registered -def register_to_request(prefers, registration_response): - pass \ No newline at end of file +def create_registration_request(prefers, supported): + _request = {} + for key, spec in RegistrationRequest.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + if _pref_key in prefers: + value = prefers[_pref_key] + elif _pref_key in supported: + value = supported[_pref_key] + else: + continue + + _request[key] = array_or_singleton(spec, value) + return _request \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 742557ba..25bd7814 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -1,12 +1,12 @@ from typing import Callable -from cryptojwt.utils import importer import pytest +from cryptojwt.utils import importer from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC -from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED -from idpyoidc.client.work_condition.transform import array_to_singleton +from idpyoidc.client.work_condition.transform import create_registration_request from idpyoidc.client.work_condition.transform import preferred_to_registered +from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED from idpyoidc.client.work_condition.transform import supported_to_preferred from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest @@ -37,6 +37,7 @@ def setup(self): for key, val in supported.items(): if isinstance(val, Callable): supported[key] = val() + # NOTE! Not checking rules self.supported = supported def test_supported(self): @@ -251,6 +252,48 @@ def test_provider_info(self): # The OP supports less than the RP assert pref["response_types_supported"] == ['code', 'id_token', 'code id_token'] + +class TestTransform2: + + @pytest.fixture(autouse=True) + def setup(self): + self.work_condition = WorkConditionOIDC() + supported = self.work_condition._supports.copy() + for service in [ + 'idpyoidc.client.oidc.access_token.AccessToken', + 'idpyoidc.client.oidc.authorization.Authorization', + 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', + 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', + 'idpyoidc.client.oidc.check_id.CheckID', + 'idpyoidc.client.oidc.check_session.CheckSession', + 'idpyoidc.client.oidc.end_session.EndSession', + 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', + 'idpyoidc.client.oidc.read_registration.RegistrationRead', + 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', + 'idpyoidc.client.oidc.registration.Registration', + 'idpyoidc.client.oidc.userinfo.UserInfo', + 'idpyoidc.client.oidc.webfinger.WebFinger' + ]: + cls = importer(service) + supported.update(cls._supports) + + for key, val in supported.items(): + if isinstance(val, Callable): + supported[key] = val() + + self.supported = supported + preference = { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + # "client_name#ja-Jpan-JP": "クライアント名", + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(preference, self.supported) + def test_registration_response(self): OP_BASEURL = 'https://example.com' provider_info_response = { @@ -276,34 +319,50 @@ def test_registration_response(self): "acr_values_supported": ['mfa'], } - preference = { - "application_type": "web", - "redirect_uris": - ["https://client.example.org/callback", - "https://client.example.org/callback2"], - "client_name": "My Example", - "client_name#ja-Jpan-JP": - "クライアント名", - "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] - } - pref = supported_to_preferred(supported=self.supported, preference=preference, + pref = supported_to_preferred(supported=self.supported, + preference=self.work_condition.prefer, base_url='https://example.com', info=provider_info_response) - registration_request = {} - for key, spec in RegistrationRequest.c_param.items(): - _pref_key = REGISTER2PREFERRED.get(key, key) - if _pref_key in pref: - value = pref[_pref_key] - elif _pref_key in self.supported: - value = self.supported[_pref_key] - else: - value = None + registration_request = create_registration_request(pref, self.supported) + + assert set(registration_request.keys()) == {'application_type', + 'backchannel_logout_session_required', + 'backchannel_logout_uri', + 'client_name', + 'client_uri', + 'contacts', + 'default_acr_values', + 'default_max_age', + 'frontchannel_logout_session_required', + 'frontchannel_logout_uri', + 'grant_types', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'id_token_signed_response_alg', + 'initiate_login_uri', + 'jwks', + 'jwks_uri', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'redirect_uris', + 'request_object_encryption_alg', + 'request_object_encryption_enc', + 'request_object_signing_alg', + 'request_uris', + 'require_auth_time', + 'response_types', + 'sector_identifier_uri', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'tos_uri', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc', + 'userinfo_signed_response_alg'} - if value: - #registration_request[key] = array_to_singleton(spec, value) - registration_request[key] = value + assert registration_request["subject_type"] == 'public' registration_response = { "application_type": "web", @@ -311,8 +370,6 @@ def test_registration_response(self): ["https://client.example.org/callback", "https://client.example.org/callback2"], "client_name": "My Example", - "client_name#ja-Jpan-JP": - "クライアント名", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", "sector_identifier_uri": @@ -329,4 +386,30 @@ def test_registration_response(self): to_use = preferred_to_registered(prefers=pref, registration_response=registration_response) - assert to_use + assert set(to_use.keys()) == {'application_type', + 'client_name', + 'client_name#ja-Jpan-JP', + 'contacts', + 'default_max_age', + 'grant_types', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'id_token_signed_response_alg', + 'jwks_uri', + 'logo_uri', + 'redirect_uris', + 'request_object_encryption_alg', + 'request_object_encryption_enc', + 'request_object_signing_alg', + 'request_uris', + 'response_modes_supported', + 'response_types', + 'sector_identifier_uri', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc', + 'userinfo_signed_response_alg'} + + assert to_use["subject_type"] == 'pairwise' From 9864eff9c6f27ad0c5df1d159157004152207213 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 26 Nov 2022 08:26:23 +0100 Subject: [PATCH 20/76] All tests green --- setup.py | 1 + src/idpyoidc/client/defaults.py | 6 +- src/idpyoidc/client/oauth2/access_token.py | 2 +- src/idpyoidc/client/oauth2/authorization.py | 5 +- src/idpyoidc/client/oidc/access_token.py | 2 +- src/idpyoidc/client/oidc/authorization.py | 11 +- src/idpyoidc/client/oidc/end_session.py | 4 +- .../client/oidc/provider_info_discovery.py | 5 +- src/idpyoidc/client/oidc/registration.py | 27 +- src/idpyoidc/client/provider/github.py | 7 + src/idpyoidc/client/provider/linkedin.py | 7 + src/idpyoidc/client/rp_handler.py | 8 + src/idpyoidc/client/service_context.py | 8 +- src/idpyoidc/client/util.py | 3 + .../client/work_condition/__init__.py | 76 +++++- src/idpyoidc/client/work_condition/oidc.py | 19 +- .../client/work_condition/transform.py | 42 +++- tests/pub_client.jwks | 2 +- tests/request123456.jwt | 2 +- tests/test_08_transform.py | 45 +--- tests/test_09_work_condition.py | 234 ++++++++++++++++++ tests/test_client_02b_entity_metadata.py | 43 ++-- tests/test_client_04_service.py | 4 +- tests/test_client_06_client_authn.py | 2 +- tests/test_client_12_client_auth.py | 2 +- .../test_client_14_service_context_impexp.py | 2 +- tests/test_client_21_oidc_service.py | 59 +++-- tests/test_client_23_pkce.py | 2 +- tests/test_client_26_read_registration.py | 2 +- tests/test_client_27_conversation.py | 15 +- tests/test_client_28_rp_handler_oidc.py | 16 +- tests/test_client_30_rph_defaults.py | 54 ++-- tests/test_client_41_rp_handler_persistent.py | 8 +- 33 files changed, 558 insertions(+), 167 deletions(-) create mode 100644 tests/test_09_work_condition.py diff --git a/setup.py b/setup.py index 3417af4f..3c2be62a 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,7 @@ def run_tests(self): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ "cryptojwt>=1.8.1", diff --git a/src/idpyoidc/client/defaults.py b/src/idpyoidc/client/defaults.py index 3237cd61..b8d50659 100644 --- a/src/idpyoidc/client/defaults.py +++ b/src/idpyoidc/client/defaults.py @@ -26,7 +26,7 @@ }, } -DEFAULT_CLIENT_METADATA = { +DEFAULT_CLIENT_PREFERENCES = { "application_type": "web", "response_types": [ "code", @@ -37,6 +37,7 @@ "code token", ], "token_endpoint_auth_method": "client_secret_basic", + "scopes_supported": ["openid"], } DEFAULT_USAGE = { @@ -47,8 +48,7 @@ # Using PKCE is default DEFAULT_CLIENT_CONFIGS = { "": { - "metadata": DEFAULT_CLIENT_METADATA, - "usage": DEFAULT_USAGE, + "preference": DEFAULT_CLIENT_PREFERENCES, "add_ons": { "pkce": { "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 67c51517..a04a864b 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -27,7 +27,7 @@ class AccessToken(Service): response_body_type = "json" _supports = { - "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg": get_signing_algs, } diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 2fd49fdc..75809075 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -120,12 +120,13 @@ def _do_redirect_uris(self, base_url, hex, context, callback_uris, response_type callback_uris['redirect_uris'] = {} for flow_type, path in self._callback_path['redirect_uris'].items(): if self._do_flow(flow_type, response_types): - callback_uris['redirect_uris'][flow_type] = self.get_uri(base_url, path, hex) + callback_uris['redirect_uris'][flow_type] = [ + self.get_uri(base_url, path, hex)] else: callback_uris['redirect_uris'] = {} for flow_type, path in self._callback_path['redirect_uris'].items(): if self._do_flow(flow_type, response_types): - callback_uris['redirect_uris'][flow_type] = self.get_uri(base_url, path, hex) + callback_uris['redirect_uris'][flow_type] = [self.get_uri(base_url, path, hex)] return callback_uris def construct_uris(self, diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index d35698a5..16f17f3c 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -23,7 +23,7 @@ class AccessToken(access_token.AccessToken): error_msg = oidc.ResponseMessage _supports = { - "token_endpoint_auth_methods_supported": get_client_authn_methods, + "token_endpoint_auth_method": get_client_authn_methods, "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 4b243746..271c679d 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -37,6 +37,8 @@ class Authorization(authorization.Authorization): "request_object_encryption_enc_values_supported": work_condition.get_encryption_encs, "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', 'code id_token', 'code idtoken token'], + 'request_parameter_supported': None, + 'request_uri_parameter_supported': None, "request_uris": None, "request_parameter": None, "encrypt_request_object_supported": None, @@ -61,7 +63,8 @@ def __init__(self, client_get, conf=None): self.oidc_pre_construct, ] self.post_construct = [self.oidc_post_construct] - self.default_request_args = {'scope': ['openid']} + if 'scope' not in self.default_request_args: + self.default_request_args['scope'] = ['openid'] def set_state(self, request_args, **kwargs): try: @@ -127,7 +130,11 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): if _scope: request_args["scope"] = _scope else: - request_args["scope"] = "openid" + _scope = _context.get_preference("scopes_supported") + if _scope: + request_args['scope'] = _scope + else: + request_args["scope"] = "openid" elif "openid" not in request_args["scope"]: request_args["scope"].append("openid") diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 1f2e6577..eacb1861 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -21,7 +21,7 @@ class EndSession(Service): response_body_type = "html" _supports = { - "post_logout_redirect_uri": None, + "post_logout_redirect_uris": None, 'frontchannel_logout_supported': None, "frontchannel_logout_uri": None, "frontchannel_logout_session_required": None, @@ -33,7 +33,7 @@ class EndSession(Service): _callback_path = { "frontchannel_logout_uri": "fc_logout", "backchannel_logout_uri": "bc_logout", - "post_logout_redirect_uri": "session_logout" + "post_logout_redirect_uris": "session_logout" } def __init__(self, client_get, conf=None): diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 3b52766a..ac21d3fd 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -2,6 +2,7 @@ from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.oauth2 import server_metadata +from idpyoidc.client.work_condition.transform import supported_to_preferred from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -53,8 +54,8 @@ def __init__(self, client_get, conf=None): def update_service_context(self, resp, **kwargs): _context = self.client_get("service_context") - self._update_service_context(resp) # set endpoints and import keys - self.match_preferences(resp, _context.issuer) + self._update_service_context(resp) # set endpoints and import keys + _context.map_supported_to_preferred(resp) if "pre_load_keys" in self.conf and self.conf["pre_load_keys"]: _jwks = _context.keyjar.export_jwks_as_json(issuer=resp["issuer"]) logger.info("Preloaded keys for {}: {}".format(resp["issuer"], _jwks)) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index fd708b12..0d2ead7e 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -2,6 +2,7 @@ from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service +from idpyoidc.client.work_condition.transform import create_registration_request from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -60,22 +61,20 @@ def oidc_post_construct(self, request_args=None, **kwargs): return request_args def update_service_context(self, resp, key="", **kwargs): - if "token_endpoint_auth_method" not in resp: - resp["token_endpoint_auth_method"] = "client_secret_basic" + # if "token_endpoint_auth_method" not in resp: + # resp["token_endpoint_auth_method"] = "client_secret_basic" _context = self.client_get("service_context") + _context.map_preferred_to_registered(resp) _keyjar = _context.keyjar _context.registration_response = resp - _client_id = resp.get("client_id") + _client_id = _context.get_usage("client_id") if _client_id: - _context.work_condition.set_usage("client_id", _client_id) if _client_id not in _keyjar: _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) - _client_secret = resp.get("client_secret") + _client_secret = _context.get_usage("client_secret") if _client_secret: - _context.set_usage("client_secret", _client_secret) - # _context.client_secret = _client_secret _keyjar.add_symmetric("", _client_secret) _keyjar.add_symmetric(_client_id, _client_secret) try: @@ -88,3 +87,17 @@ def update_service_context(self, resp, key="", **kwargs): _context.set_usage("registration_access_token", resp["registration_access_token"]) except KeyError: pass + + def gather_request_args(self, **kwargs): + """ + + @param kwargs: + @return: + """ + _context = self.client_get("service_context") + req_args = create_registration_request(_context.work_condition.prefer, _context.supports()) + if "request_args" in self.conf: + req_args.update(self.conf["request_args"]) + + req_args.update(kwargs) + return req_args \ No newline at end of file diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index cf841f38..0e0e2fa5 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -1,5 +1,7 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo +from idpyoidc.client.work_condition import get_client_authn_methods +from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_STRING from idpyoidc.message import Message @@ -25,6 +27,11 @@ class AccessToken(access_token.AccessToken): error_msg = oauth2.TokenErrorResponse response_body_type = "urlencoded" + _supports = { + "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs + } + class UserInfo(userinfo.UserInfo): response_cls = Message diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 24d55020..8ddede1d 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -1,5 +1,7 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo +from idpyoidc.client.work_condition import get_client_authn_methods +from idpyoidc.client.work_condition import get_signing_algs from idpyoidc.message import SINGLE_OPTIONAL_JSON from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT @@ -31,6 +33,11 @@ class AccessToken(access_token.AccessToken): response_cls = AccessTokenResponse error_msg = oauth2.TokenErrorResponse + _supports = { + "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs + } + class UserInfo(userinfo.UserInfo): response_cls = UserSchema diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 2302a6d9..1ee5f4a6 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -183,6 +183,14 @@ def init_client(self, issuer): except KeyError: _services = self.services + if not 'base_url' in _cnf: + _cnf['base_url'] = self.base_url + + if self.jwks_uri: + _cnf['jwks_uri'] = self.jwks_uri + elif self.jwks: + _cnf['jwks'] = self.jwks + try: client = self.client_cls( services=_services, diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 7edcac64..8d94bc11 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -348,7 +348,9 @@ def map_supported_to_preferred(self, info: Optional[dict] = None): info=info) return self.work_condition.prefer - def map_preferred_to_registered(self): - self.work_condition.use = preferred_to_registered(self.work_condition.prefer, - self.work_condition.use) + def map_preferred_to_registered(self, registration_response: Optional[dict] = None): + self.work_condition.use = preferred_to_registered( + self.work_condition.prefer, + supported=self.supports(), + registration_response=registration_response) return self.work_condition.use diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index 13d16965..a5d78c73 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -321,3 +321,6 @@ def implicit_response_types(a): if set(typ.split(' ')) in IMPLICIT_RESPONSE_TYPES: res.append(typ) return res + +def get_uri(base_url, path, hex): + return f"{base_url}/{path}/{hex}" diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index 298c952c..df7452ef 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -2,12 +2,16 @@ from typing import Callable from typing import Optional +from cryptojwt import KeyJar +from cryptojwt.exception import IssuerNotFound from cryptojwt.jwe import SUPPORTED +from cryptojwt.jwk.hmac import SYMKey from cryptojwt.jws.jws import SIGNER_ALGS +from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import importer from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD -from idpyoidc.client.service import Service +from idpyoidc.client.util import get_uri from idpyoidc.impexp import ImpExp from idpyoidc.util import qualified_name @@ -79,7 +83,7 @@ def _callback_uris(self, base_url, hex): callback_uri = {} for key in _uri: - callback_uri[key] = Service.get_uri(base_url, self.callback_path[key], hex) + callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) return callback_uri def construct_redirect_uris(self, @@ -101,8 +105,65 @@ def verify_rules(self): def locals(self, info): pass - def load_conf(self, info, supports): - for attr, val in info.items(): + def _keyjar(self, keyjar=None, conf=None, entity_id=""): + _uri_path = '' + if keyjar is None: + if "keys" in conf: + keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['keys'].get('uri_path') + elif "key_conf" in conf and conf["key_conf"]: + keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['key_conf'].get('uri_path') + else: + _keyjar = KeyJar() + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") + + if "" in _keyjar and entity_id: + # make sure I have the keys under my own name too (if I know it) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + + _httpc_params = conf.get("httpc_params") + if _httpc_params: + _keyjar.httpc_params = _httpc_params + + return _keyjar, _uri_path + else: + return keyjar, _uri_path + + def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): + _jwks = _jwks_uri = None + _id = self.get_preference('client_id') + keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) + + _secret = self.get_preference('client_secret') + if _secret: + keyjar.add_symmetric(issuer_id=_id, key=_secret) + keyjar.add_symmetric(issuer_id='', key=_secret) + + # now that keys are in the Key Jar, now for how to publish it + if 'jwks_uri' in configuration: # simple + _jwks_uri = configuration.get('jwks_uri') + elif uri_path: + _jwks_uri = f"{configuration.get('base_url')}{uri_path}" + else: # jwks or nothing + # if only the client secret, no need to publish as a JWKS + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + pass + else: + if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): + pass + else: + _jwks = keyjar.export_jwks() + + return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} + + def load_conf(self, configuration, supports): + for attr, val in configuration.items(): if attr == "preference": for k, v in val.items(): if k in supports: @@ -110,7 +171,12 @@ def load_conf(self, info, supports): elif attr in supports: self.set_preference(attr, val) - self.locals(info) + self.locals(configuration) + + for key, val in self.handle_keys(configuration).items(): + if val: + self.set_preference(key, val) + self.verify_rules() return self diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index 2e609fb0..e949c4ea 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -1,6 +1,9 @@ import os from typing import Optional +from cryptojwt import KeyJar +from cryptojwt.key_jar import init_key_jar + from idpyoidc.client import work_condition @@ -46,17 +49,11 @@ def __init__(self, work_condition.WorkCondition.__init__(self, prefer=prefer, callback_path=callback_path) def verify_rules(self): - if self.get_preference("request_parameter") and self.get_preference("request_uri"): - raise ValueError("You have to chose one of 'request_parameter' and 'request_uri'." - " you can't have both.") - - _cb_uris = self.get_preference('callback_uris') - if _cb_uris: - self.set_preference('redirect_uris', list(_cb_uris.values())) # just overwrite - else: - _uris = self.get_preference('redirect_uris') - if _uris: - self.set_preference('callback_uris', {'redirect_uris': {'code': _uris}}) + if self.get_preference("request_parameter_supported") and self.get_preference( + "request_uri_parameter_supported"): + raise ValueError( + "You have to chose one of 'request_parameter_supported' and " + "'request_uri_parameter_supported'. You can't have both.") if not self.get_preference('encrypt_userinfo_supported'): self.set_preference('userinfo_encryption_alg_values_supported', []) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index 83f5e47e..c2fcfe69 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -20,10 +20,11 @@ "default_acr_values": "acr_values_supported", "subject_type": "subject_types_supported", "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", "response_types": "response_types_supported", "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", # "display": "display_values_supported", # "claims": "claims_supported", # "request": "request_parameter_supported", @@ -44,6 +45,7 @@ 'request_uri': "request_uris", 'grant_type': "grant_types", "scope": 'scopes_supported', + 'post_logout_redirect_uri': "post_logout_redirect_uris" } @@ -109,7 +111,18 @@ def array_or_singleton(claim_spec, values): return values -def preferred_to_registered(prefers: dict, registration_response: Optional[dict] = None): +def _is_subset(a, b): + if isinstance(a, list): + if isinstance(b, list): + return set(b).issubset(set(a)) + elif isinstance(b, list): + return a in b + else: + return a == b + + +def preferred_to_registered(prefers: dict, supported: dict, + registration_response: Optional[dict] = None): """ The claims with values that are returned from the OP is what goes unless (!!) the values returned are not within the supported values. @@ -122,25 +135,33 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] if registration_response: for key, val in registration_response.items(): - registered[key] = val # Should I just accept with the OP says ?? + if key in REGISTER2PREFERRED: + if _is_subset(val, supported.get(REGISTER2PREFERRED[key])): + registered[key] = val + else: + logger.warning(f'OP tells me to do something I do not support: {key} = {val}') + else: + registered[key] = val # Should I just accept with the OP says ?? for key, spec in RegistrationResponse.c_param.items(): if key in registered: continue _pref_key = REGISTER2PREFERRED.get(key, key) - _preferred_values = prefers.get(_pref_key) + _preferred_values = prefers.get(_pref_key, prefers.get(key)) if not _preferred_values: continue + registered[key] = array_or_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) for key, val in prefers.items(): - if PREFERRED2REGISTER.get(key): - continue - if key not in _rr_keys: - registered[key] = val + _reg_key = PREFERRED2REGISTER.get(key, key) + if _reg_key not in _rr_keys: + # If they are not part of the registration request I do not knoe if it is supposed to + # be a singleton or an array. So just add it as is. + registered[_reg_key] = val logger.debug(f"Entity registered: {registered}") return registered @@ -157,5 +178,8 @@ def create_registration_request(prefers, supported): else: continue + if not value: + continue + _request[key] = array_or_singleton(spec, value) - return _request \ No newline at end of file + return _request diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index ab46a129..826a4289 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2IiLCAic2NvcGUiOiAib3BlbmlkIiwgIm5vbmNlIjogIkZ3eG5aTk9Hc3hHUVRpblhRTmpwaUtqclA2WVQ1YUJBeGFmNV9yV1l2Y3ciLCAiY2xpZW50X2lkIjogImNsaWVudF9pZCIsICJpc3MiOiAiY2xpZW50X2lkIiwgImlhdCI6IDE2Njg5MzQzNDYsICJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZS5jb20iXX0.pzfhYGlDcii7YEtbz9PxER3_ILD-3FfqyoVWOIAMuZ8Ryd0Nx4j7qbG43zjLIm0bTTm1VGXbeH7DC8mnVQfEEX-b_QYNjcFbH-1O9xTKgIq_B3RGfHu0TpAbTkAtJZO92pFKFDmrUKGDg0B00J0XKH4rG_gVY8u-X4x63veVus9peKtzOgyN0WiUixwIYgwpfQx8-CR4MozJuJ9Q28QIDft53Vwl2puwvTy2n29IHIO6qIvO-Aho02FY9guC9QVhMHcSlUxilg_iqCNQQ6GE1yB2pDP2sTWpQI97BZ4XNeQ_JC86gQnjVsaWhGAhCFuJDLfIsTPVdLhWGCEIIZkkLQ \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJjNENQb1N0Q1BnY0hhOHVySG1kYk43d2V1bnZteUVKTk0wZ2oyUmxvY01jIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5MjgxODA3LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.BGBFSfVc5TNAIU2_4Z2QE4tfRPt-IVZwV_4zUGQWawmsDJZbo1dc7NZZjsXyI_1gHsG7mn7utylqtS8Q-jStg3ikyn886eON6UUIbf9YX--rkl3D8_es9CLnFL0jbAOX2sl573ujMBG-IlEgTLyIgzOgibLfJau0x-JOlQLP9l-dFqFxaE9LeeiPfm2sa7Y6HLk9BkjW7rC2UsrTLuf5PTYwriF7n9IXcCquZbNcbXdk2xA9Oy3ozvxrpMlxqA45peNTK9d97hE5cfJ6sB0vrouDPJ3x_o36AXbGx82PVL8Ce0ZCRXytLZqEtS21nwJiu_nadvvg8fodt81Gm8IOmQ \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 25bd7814..2e19ebf6 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -1,12 +1,12 @@ from typing import Callable -import pytest from cryptojwt.utils import importer +import pytest from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC +from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED from idpyoidc.client.work_condition.transform import create_registration_request from idpyoidc.client.work_condition.transform import preferred_to_registered -from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED from idpyoidc.client.work_condition.transform import supported_to_preferred from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest @@ -70,12 +70,14 @@ def test_supported(self): 'jwks_uri', 'logo_uri', 'policy_uri', - 'post_logout_redirect_uri', + 'post_logout_redirect_uris', 'redirect_uris', 'request_object_encryption_alg_values_supported', 'request_object_encryption_enc_values_supported', 'request_object_signing_alg_values_supported', 'request_parameter', + 'request_parameter_supported', + 'request_uri_parameter_supported', 'request_uris', 'requests_dir', 'require_auth_time', @@ -110,8 +112,8 @@ def test_oidc_setup(self): 'op_policy_uri', 'op_tos_uri', 'registration_endpoint', - 'request_parameter_supported', - 'request_uri_parameter_supported', + # 'request_parameter_supported', + # 'request_uri_parameter_supported', 'require_request_uri_registration', 'service_documentation', 'token_endpoint', @@ -138,7 +140,7 @@ def test_oidc_setup(self): 'jwks', 'logo_uri', 'policy_uri', - 'post_logout_redirect_uri', + 'post_logout_redirect_uris', 'redirect_uris', 'request_parameter', 'request_uris', @@ -179,7 +181,8 @@ def test_oidc_setup(self): if _pref_key in self.supported: reg_claim.append(key) - assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == set() + assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == { + 'post_logout_redirect_uri'} # Which ones are list -> singletons @@ -327,39 +330,18 @@ def test_registration_response(self): registration_request = create_registration_request(pref, self.supported) assert set(registration_request.keys()) == {'application_type', - 'backchannel_logout_session_required', - 'backchannel_logout_uri', 'client_name', - 'client_uri', 'contacts', - 'default_acr_values', 'default_max_age', - 'frontchannel_logout_session_required', - 'frontchannel_logout_uri', 'grant_types', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', 'id_token_signed_response_alg', - 'initiate_login_uri', - 'jwks', - 'jwks_uri', 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uri', 'redirect_uris', - 'request_object_encryption_alg', - 'request_object_encryption_enc', 'request_object_signing_alg', - 'request_uris', - 'require_auth_time', 'response_types', - 'sector_identifier_uri', 'subject_type', 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', - 'tos_uri', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc', 'userinfo_signed_response_alg'} assert registration_request["subject_type"] == 'public' @@ -384,26 +366,23 @@ def test_registration_response(self): } to_use = preferred_to_registered(prefers=pref, + supported=self.supported, registration_response=registration_response) assert set(to_use.keys()) == {'application_type', 'client_name', - 'client_name#ja-Jpan-JP', 'contacts', 'default_max_age', 'grant_types', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', 'id_token_signed_response_alg', 'jwks_uri', 'logo_uri', 'redirect_uris', - 'request_object_encryption_alg', - 'request_object_encryption_enc', 'request_object_signing_alg', 'request_uris', 'response_modes_supported', 'response_types', + 'scope', 'sector_identifier_uri', 'subject_type', 'token_endpoint_auth_method', diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py new file mode 100644 index 00000000..96b5f2ae --- /dev/null +++ b/tests/test_09_work_condition.py @@ -0,0 +1,234 @@ +from typing import Callable + +from cryptojwt.utils import importer +import pytest as pytest + +from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC +from idpyoidc.client.work_condition.transform import create_registration_request +from idpyoidc.client.work_condition.transform import preferred_to_registered +from idpyoidc.client.work_condition.transform import supported_to_preferred + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + + +class TestWorkCondition: + + @pytest.fixture(autouse=True) + def setup(self): + self.work_condition = WorkConditionOIDC() + supported = self.work_condition._supports.copy() + for service in [ + 'idpyoidc.client.oidc.access_token.AccessToken', + 'idpyoidc.client.oidc.authorization.Authorization', + 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', + 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', + 'idpyoidc.client.oidc.check_id.CheckID', + 'idpyoidc.client.oidc.check_session.CheckSession', + 'idpyoidc.client.oidc.end_session.EndSession', + 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', + 'idpyoidc.client.oidc.read_registration.RegistrationRead', + 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', + 'idpyoidc.client.oidc.registration.Registration', + 'idpyoidc.client.oidc.userinfo.UserInfo', + 'idpyoidc.client.oidc.webfinger.WebFinger' + ]: + cls = importer(service) + supported.update(cls._supports) + + for key, val in supported.items(): + if isinstance(val, Callable): + supported[key] = val() + + self.supported = supported + + def test_load_conf(self): + # Only symmetric key + client_conf = { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "client_id": "client_id", + "client_secret": "a longesh password", + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(client_conf, self.supported) + assert self.work_condition.get_preference('jwks') is None + assert self.work_condition.get_preference('jwks_uri') is None + + def test_load_jwks(self): + # Symmetric and asymmetric keys published as JWKS + client_conf = { + "application_type": "web", + 'base_url': "https://client.example.org/", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "client_id": "client_id", + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_secret": "a longesh password", + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(client_conf, self.supported) + assert self.work_condition.get_preference('jwks') is not None + assert self.work_condition.get_preference('jwks_uri') is None + + def test_load_jwks_uri1(self): + # Symmetric and asymmetric keys published through a jwks_uri + client_conf = { + "application_type": "web", + 'base_url': "https://client.example.org/", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYSPEC, "read_only": True}, + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(client_conf, self.supported) + assert self.work_condition.get_preference('jwks') is None + assert self.work_condition.get_preference( + 'jwks_uri') == f"{client_conf['base_url']}{client_conf['keys']['uri_path']}" + + def test_load_jwks_uri2(self): + # Symmetric and asymmetric keys published through a jwks_uri + client_conf = { + "application_type": "web", + 'base_url': "https://client.example.org/", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "jwks_uri": 'https://client.example.org/keys/jwks.json', + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(client_conf, self.supported) + assert self.work_condition.get_preference('jwks') is None + assert self.work_condition.get_preference('jwks_uri') == client_conf['jwks_uri'] + + def test_registration_response(self): + client_conf = { + "application_type": "web", + 'base_url': "https://client.example.org/", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "client_id": "client_id", + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_secret": "a longesh password", + "logo_uri": "https://client.example.org/logo.png", + 'contacts': ["ve7jtb@example.org", "mary@example.org"] + } + + self.work_condition.load_conf(client_conf, self.supported) + + OP_BASEURL = 'https://example.com' + provider_info_response = { + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "issuer": OP_BASEURL, + "jwks_uri": f"{OP_BASEURL}/static/jwks_tE2iLbOAqXhe8bqh.json", + "authorization_endpoint": f"{OP_BASEURL}/authorization", + "token_endpoint": f"{OP_BASEURL}/token", + "userinfo_endpoint": f"{OP_BASEURL}/userinfo", + "registration_endpoint": f"{OP_BASEURL}/registration", + "end_session_endpoint": f"{OP_BASEURL}/end_session", + # below are a set which the RP has default values but the OP overwrites + "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], + "response_types_supported": ['code', 'id_token', 'code id_token'], + "response_modes_supported": ['query', 'form_post', 'new_fangled'], + # this does not have a default value + "acr_values_supported": ['mfa'], + } + + pref = supported_to_preferred(supported=self.supported, + preference=self.work_condition.prefer, + base_url='https://example.com', + info=provider_info_response) + + registration_request = create_registration_request(pref, self.supported) + + assert set(registration_request.keys()) == {'application_type', + 'client_name', + 'contacts', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'logo_uri', + 'redirect_uris', + 'request_object_signing_alg', + 'response_types', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} + + assert registration_request["subject_type"] == 'public' + + registration_response = { + "application_type": "web", + "redirect_uris": + ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "client_name": "My Example", + "logo_uri": "https://client.example.org/logo.png", + "subject_type": "pairwise", + "sector_identifier_uri": + "https://other.example.net/file_of_redirect_uris.json", + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "contacts": ["ve7jtb@example.org", "mary@example.org"], + "request_uris": [ + "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + } + + to_use = preferred_to_registered(prefers=pref, + supported=self.supported, + registration_response=registration_response) + + assert set(to_use.keys()) == {'application_type', + 'client_id', + 'client_name', + 'client_secret', + 'contacts', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'jwks_uri', + 'keyjar', + 'logo_uri', + 'redirect_uris', + 'request_object_signing_alg', + 'request_uris', + 'response_modes_supported', + 'response_types', + 'scope', + 'sector_identifier_uri', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc', + 'userinfo_signed_response_alg'} + + # Not what I asked for but something I can handle + assert to_use["subject_type"] == 'pairwise' diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index e4b5d437..9de8d198 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -79,6 +79,7 @@ def test_create_client(): 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', 'jwks', + 'keyjar', 'post_logout_redirect_uris', 'redirect_uris', 'request_object_encryption_alg_values_supported', @@ -89,7 +90,7 @@ def test_create_client(): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported', @@ -101,33 +102,47 @@ def test_create_client(): # assert _context.get_preference("userinfo_signing_alg_values_supported") == ['ES256'] # How to act - _context.map_preferred_to_register() + _context.map_preferred_to_registered() assert _context.get_usage("request_uris") is None _conf_args = list(_context.collect_usage().keys()) assert _conf_args - assert len(_conf_args) == 21 + assert len(_conf_args) == 23 rr = set(RegistrationRequest.c_param.keys()) # The ones that are not defined d = rr.difference(set(_conf_args)) - assert d == {'initiate_login_uri', 'client_name', 'post_logout_redirect_uri', 'tos_uri', - 'logo_uri', 'jwks_uri', 'federation_type', 'frontchannel_logout_session_required', - 'require_auth_time', 'client_uri', 'frontchannel_logout_uri', 'request_uris', - 'sector_identifier_uri', 'organization_name', 'policy_uri', - 'default_acr_values', 'userinfo_encrypted_response_alg', - 'id_token_encrypted_response_alg', 'request_object_encryption_alg', - 'userinfo_encrypted_response_enc', 'request_object_encryption_enc', - 'id_token_encrypted_response_enc'} + assert d == {'client_name', + 'client_uri', + 'default_acr_values', + 'frontchannel_logout_session_required', + 'frontchannel_logout_uri', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'initiate_login_uri', + 'jwks_uri', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'request_object_encryption_alg', + 'request_object_encryption_enc', + 'request_uris', + 'require_auth_time', + 'sector_identifier_uri', + 'tos_uri', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc'} def test_create_client_key_conf(): client_config = CLIENT_CONFIG.copy() - client_config.update({"key_conf": KEY_CONF}) + client_config.update({ + "key_conf": KEY_CONF, + "jwks_uri": "https://example.com/keys/jwks.json" + }) client = Entity(config=client_config, client_type='oidc') - _jwks = client.get_service_context().get_preference("jwks") - assert _jwks + assert client.get_service_context().get_preference("jwks_uri") def test_create_client_keyjar(): diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 6d5191f9..5a2e1767 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -49,12 +49,12 @@ def test_1(self): assert self.service def test_use(self): - use = self.service_context.map_preferred_to_register() + use = self.service_context.map_preferred_to_registered() assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', 'grant_types', 'application_type', 'jwks', 'subject_type', 'id_token_signed_response_alg', 'default_max_age', - 'request_object_signing_alg', 'callback_uris', + 'request_object_signing_alg', 'callback_uris', 'scope', 'response_modes_supported'} def test_gather_request_args(self): diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index e5628d08..67804c0b 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -71,7 +71,7 @@ def entity(): # The following two lines is necessary since they replace provider info collection and # client registration. _entity.get_service_context().map_supported_to_preferred() - _entity.get_service_context().map_preferred_to_register() + _entity.get_service_context().map_preferred_to_registered() return _entity diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 484bf304..42e91286 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -51,7 +51,7 @@ def entity(): # The following two lines is necessary since they replace provider info collection and # client registration. entity.get_service_context().map_supported_to_preferred() - entity.get_service_context().map_preferred_to_register() + entity.get_service_context().map_preferred_to_registered() return entity diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index ae3526df..d0f8edbc 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -21,7 +21,7 @@ def test_client_info_init(): ci = ServiceContext(config=config, client_type='oidc') ci.work_condition.load_conf(config, supports=ci.supports()) ci.map_supported_to_preferred() - ci.map_preferred_to_register() + ci.map_preferred_to_registered() srvcnx = ServiceContext(base_url=BASE_URL).load(ci.dump()) diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index f5dab9af..6e0df3d3 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -91,7 +91,7 @@ def create_request(self): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() self.context = _context self.service = entity.client_get("service", "authorization") @@ -329,7 +329,7 @@ def create_request(self): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() self.service = entity.client_get("service", "authorization") @@ -544,7 +544,7 @@ def create_service(self): "response_types_supported": ["code"], "request_object_signing_alg_values_supported": ["ES256"], "encrypt_id_token_supported": False, # default - "token_endpoint_auth_methods_supported": ["private_key_jwt"], + "token_endpoint_auth_method": ["private_key_jwt"], "token_endpoint_auth_signing_alg_values_supported": ["ES256"], "userinfo_signing_alg_values_supported": ["ES256"], "post_logout_redirect_uris": ["https://rp.example.com/post"], @@ -793,18 +793,18 @@ def test_post_parse(self): self.service.update_service_context(resp) # static client registration - _context.map_preferred_to_register() + _context.map_preferred_to_registered() use_copy = self.service.client_get("service_context").work_condition.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] + del use_copy['keyjar'] + del use_copy['callback_uris'] assert use_copy == {'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', - 'callback_uris': { - 'redirect_uris': {'code': ['https://example.com/cli/authz_cb']}}, 'client_id': 'client_id', 'client_secret': 'a longesh password', 'contacts': ['ops@example.org'], @@ -817,7 +817,7 @@ def test_post_parse(self): 'request_object_signing_alg': 'ES256', 'response_modes_supported': ['query', 'fragment', 'form_post'], 'response_types': ['code'], - 'scope': ['openid', 'profile', 'email', 'address', 'phone'], + 'scope': ['openid'], 'subject_type': 'public', 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', @@ -854,19 +854,19 @@ def test_post_parse_2(self): self.service.update_service_context(resp) # static client registration - _context.map_preferred_to_register() + _context.map_preferred_to_registered() use_copy = self.service.client_get("service_context").work_condition.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] + del use_copy['keyjar'] + del use_copy['callback_uris'] assert use_copy == { 'application_type': 'web', 'backchannel_logout_session_required': True, 'backchannel_logout_uri': 'https://rp.example.com/back', - 'callback_uris': { - 'redirect_uris': {'code': ['https://example.com/cli/authz_cb']}}, 'client_id': 'client_id', 'client_secret': 'a longesh password', 'contacts': ['ops@example.org'], @@ -879,7 +879,7 @@ def test_post_parse_2(self): 'request_object_signing_alg': 'ES256', 'response_modes_supported': ['query', 'fragment', 'form_post'], 'response_types': ['code'], - 'scope': ['openid', 'profile', 'email', 'address', 'phone'], + 'scope': ['openid'], 'subject_type': 'public', 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', @@ -924,16 +924,37 @@ def create_request(self): def test_construct(self): _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 5 + assert set(_req.keys()) == {'application_type', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'response_types', + 'subject_type', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} def test_config_with_post_logout(self): - self.service.client_get("service_context").work_condition.set_usage( + self.service.client_get("service_context").work_condition.set_preference( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert len(_req) == 6 - assert "post_logout_redirect_uri" in _req + assert set(_req.keys()) == {'application_type', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'post_logout_redirect_uri', + 'redirect_uris', + 'request_object_signing_alg', + 'response_types', + 'subject_type', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} + assert "post_logout_redirect_uri" in _req.keys() def test_config_with_required_request_uri(): @@ -1228,7 +1249,7 @@ def create_request(self): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() self.service = entity.client_get("service", "end_session") def test_construct(self): @@ -1269,7 +1290,7 @@ def test_authz_service_conf(): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() service = entity.client_get("service", "authorization") req = service.construct() @@ -1292,7 +1313,7 @@ def test_jwks_uri_conf(): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() assert _context.get_usage("jwks_uri") @@ -1318,6 +1339,6 @@ def test_jwks_uri_arg(): _context = entity.client_get("service_context") _context.issuer = "https://example.com" _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() assert _context.get_usage("jwks_uri") diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index a77d9bda..0cd4f195 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -70,7 +70,7 @@ def create_client(self): do_add_ons(config["add_ons"], self.entity.client_get("services")) _context = self.entity.get_service_context() _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() def test_add_code_challenge_default_values(self): auth_serv = self.entity.client_get("service", "authorization") diff --git a/tests/test_client_26_read_registration.py b/tests/test_client_26_read_registration.py index ba4fed7d..b295c500 100644 --- a/tests/test_client_26_read_registration.py +++ b/tests/test_client_26_read_registration.py @@ -39,7 +39,7 @@ def create_request(self): self.entity = Entity(config=client_config, services=services) _context = self.entity.get_service_context() _context.map_supported_to_preferred() - _context.map_preferred_to_register() + _context.map_preferred_to_registered() self.reg_service = self.entity.client_get("service", "registration") self.read_service = self.entity.client_get("service", "registration_read") diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index da698e21..342e11fd 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -140,7 +140,7 @@ def test_conversation(): "contacts": ["ops@example.org"], "redirect_uris": [f"{RP_BASEURL}/authz_cb"], "response_types": ["code"], - "scope": ["openid", "profile", "email", "address", "phone"], + "scopes_supported": ["openid", "profile", "email", "address", "phone"], "request_object_signing_alg": "ES256", "request_uris": [f"{RP_BASEURL}/requests"], "token_endpoint_auth_methods_supported": ["private_key_jwt"], @@ -523,17 +523,8 @@ def test_conversation(): assert info["url"] == "https://example.org/op/token" _qp = parse_qs(info["body"]) - assert set(_qp.keys()) == { - "grant_type", - "redirect_uri", - "state", - "code", - "client_assertion", - "client_assertion_type" - } - assert info["headers"] == { - "Content-Type": "application/x-www-form-urlencoded", - } + assert set(_qp.keys()) == {'state', 'code', 'client_id', 'grant_type', 'redirect_uri'} + assert info["headers"]["Content-Type"] == "application/x-www-form-urlencoded" # create the IdToken _jwt = JWT(OP_KEYJAR, OP_BASEURL, lifetime=3600, sign=True, sign_alg="RS256") diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 5950f015..a6bfa41f 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -64,7 +64,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["r_basicprofile", "r_emailaddress"], - "token_endpoint_auth_methods_supported": ["client_secret_post"], + "token_endpoint_auth_method": ["client_secret_post"], }, "provider_info": { "authorization_endpoint": "https://www.linkedin.com/oauth/v2/authorization", @@ -85,7 +85,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["email", "public_profile"], - "token_endpoint_auth_methods_supported": [], + "token_endpoint_auth_method": [], }, "redirect_uris": ["{}/authz_cb/facebook".format(BASE_URL)], "provider_info": { @@ -113,7 +113,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["user", "public_repo", 'openid'], - "token_endpoint_auth_methods_supported": [], + "token_endpoint_auth_method": [], "verify_args": {"allow_sign_alg_none": True}, }, "provider_info": { @@ -252,7 +252,8 @@ def test_init_client(self): _pref = [k for k, v in _context.prefers().items() if v] assert set(_pref) == {'jwks', 'client_id', 'client_secret', 'redirect_uris', - 'response_types_supported', 'callback_uris'} + 'response_types_supported', 'callback_uris', 'scopes_supported', + 'keyjar'} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -327,7 +328,7 @@ def test_create_callbacks(self): assert set(cb['redirect_uris'].keys()) == {'code'} _hash = _context.iss_hash - assert cb['redirect_uris']["code"] == f"https://example.com/rp/authz_cb/{_hash}" + assert cb['redirect_uris']["code"] == [f"https://example.com/rp/authz_cb/{_hash}"] assert list(self.rph.hash2issuer.keys()) == [_hash] @@ -358,10 +359,11 @@ def test_begin(self): } # nonce and state are created on the fly so can't check for those + # that all values are lists is a parse_qs artifact. assert query["client_id"] == ["eeeeeeeee"] assert query["redirect_uri"] == ["https://example.com/rp/authz_cb/github"] assert query["response_type"] == ["code"] - assert query["scope"] == ["user public_repo"] + assert set(query["scope"][0].split(' ')) == {"openid", "user", "public_repo"} def test_get_session_information(self): res = self.rph.begin(issuer_id="github") @@ -398,7 +400,7 @@ def test_get_client_authn_method(self): _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session["iss"]] authn_method = self.rph.get_client_authn_method(client, "token_endpoint") - assert authn_method == "" + assert authn_method == '' res = self.rph.begin(issuer_id="linkedin") _session = self.rph.get_session_information(res["state"]) diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index 2fdaa79a..b2406862 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -1,9 +1,9 @@ from urllib.parse import parse_qs from urllib.parse import urlparse +from cryptojwt.key_jar import build_keyjar import pytest import responses -from cryptojwt.key_jar import build_keyjar from idpyoidc.client.defaults import DEFAULT_KEY_DEFS from idpyoidc.client.rp_handler import RPHandler @@ -35,15 +35,19 @@ def test_init_client(self): _context = client.client_get("service_context") - assert set(_context.config.conf["metadata"].keys()) == { - "application_type", - "response_types", - "token_endpoint_auth_method" - } - assert _context.config.conf["usage"] == { - "scope": ["openid"], - "jwks_uri": True - } + assert set(_context.work_condition.prefer.keys()) == { + 'application_type', + 'callback_uris', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'jwks_uri', + 'redirect_uris', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'scopes_supported', + 'token_endpoint_auth_method', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported'} assert list(_context.keyjar.owners()) == ["", BASE_URL] keys = _context.keyjar.get_issuer_keys("") @@ -91,17 +95,25 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.work_condition.use.keys()) == { - "token_endpoint_auth_method", - "response_types", - "scope", - "application_type", - 'redirect_uris', - 'id_token_signed_response_alg', - 'grant_types' - } + assert set(_context.work_condition.use.keys()) == {'application_type', + 'callback_uris', + 'client_id', + 'client_secret', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks_uri', + 'redirect_uris', + 'request_object_signing_alg', + 'response_modes_supported', + 'response_types', + 'scope', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} assert _context.get_client_id() == "client uno" - assert _context.get("client_secret") == "VerySecretAndLongEnough" + assert _context.get_usage("client_secret") == "VerySecretAndLongEnough" assert _context.get("issuer") == ISS_ID res = self.rph.init_authorization(client) @@ -156,4 +168,4 @@ def test_begin_2(self): rsps.add("POST", request_uri, body=_jws, status=200) self.rph.do_client_registration(client, ISS_ID) - assert "jwks" in _context.get("registration_response") + assert "jwks_uri" in _context.get("registration_response") diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 3c0b9804..dcf6b22c 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -106,7 +106,7 @@ "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { "response_types": ["code"], - "scope": ["user", "public_repo"], + "scopes_supported": ["user", "public_repo"], "token_endpoint_auth_method": "", "verify_args": {"allow_sign_alg_none": True}, }, @@ -229,7 +229,7 @@ def test_do_client_registration(self): # only 2 things should have happened assert rph_1.hash2issuer["github"] == issuer - assert not client.client_get("service_context").callback.get("post_logout_redirect_uris") + assert not client.client_get("service_context").get_usage("post_logout_redirect_uris") def test_do_client_setup(self): rph_1 = RPHandler( @@ -241,7 +241,7 @@ def test_do_client_setup(self): _context = client.client_get("service_context") assert _context.get_client_id() == "eeeeeeeee" - assert _context.get("client_secret") == "aaaaaaaaaaaaaaaaaaaa" + assert _context.get_usage("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.get("issuer") == _github_id _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -313,7 +313,7 @@ def test_get_client_from_session_key(self): # redo rph_1.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").work_condition.metadata["redirect_uris"] = [] + cli2.client_get("service_context").set_usage("redirect_uris", []) rph_1.do_client_registration(state=res["state"]) def test_finalize_auth(self): From 5062006100338fe39795f589fa31c0a64d0e7c58 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 27 Nov 2022 09:28:53 +0100 Subject: [PATCH 21/76] Cleaned up code and removed keyjar from work_condition. --- src/idpyoidc/client/entity.py | 29 ++----- src/idpyoidc/client/oauth2/__init__.py | 3 - src/idpyoidc/client/rp_handler.py | 2 - src/idpyoidc/client/service_context.py | 78 +++++++++---------- .../client/work_condition/__init__.py | 10 ++- src/idpyoidc/client/work_condition/oidc.py | 3 - src/idpyoidc/context.py | 36 +-------- tests/request123456.jwt | 2 +- tests/test_client_01_service_context.py | 14 +--- tests/test_client_02b_entity_metadata.py | 12 +-- tests/test_client_04_service.py | 18 +++-- tests/test_client_20_oauth2.py | 2 +- tests/test_client_21_oidc_service.py | 76 ++++-------------- tests/test_client_28_rp_handler_oidc.py | 9 +-- tests/test_client_41_rp_handler_persistent.py | 10 +-- 15 files changed, 101 insertions(+), 203 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 977308b8..f9b3a7f6 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -77,7 +77,6 @@ def __init__( keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, services: Optional[dict] = None, - jwks_uri: Optional[str] = "", httpc_params: Optional[dict] = None, client_type: Optional[str] = "oauth2" ): @@ -89,16 +88,6 @@ def __init__( config = get_configuration(config) - if keyjar: - _kj = keyjar.copy() - else: - _kj = None - - self._service_context = ServiceContext( - keyjar=keyjar, config=config, jwks_uri=jwks_uri, httpc_params=self.httpc_params, - client_type=client_type, client_get=self.client_get - ) - if config: _srvs = config.conf.get("services") else: @@ -114,22 +103,18 @@ def __init__( self._service = init_services(service_definitions=_srvs, client_get=self.client_get) - self.setup_client_authn_methods(config) + self._service_context = ServiceContext( + keyjar=keyjar, config=config, httpc_params=self.httpc_params, + client_type=client_type, client_get=self.client_get + ) - jwks_uri = jwks_uri or self._service_context.get("jwks_uri") - set_jwks_uri_or_jwks(self._service_context, config, jwks_uri, self._service_context.keyjar) + self.keyjar = self._service_context.get_preference('keyjar') + + self.setup_client_authn_methods(config) # Deal with backward compatibility self.backward_compatibility(config) - self._service_context.work_condition.load_conf(config.conf, - supports=self._service_context.supports()) - - _response_types = self._service_context.get_preference( - 'response_types_supported', - self._service_context.supports().get('response_types_supported', [])) - - self._service_context.construct_uris(response_types=_response_types) def client_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 2a2a7125..4eb60c0f 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -39,7 +39,6 @@ def __init__( config=None, httplib=None, services=None, - jwks_uri="", httpc_params=None, client_type: Optional[str] = "" ): @@ -53,7 +52,6 @@ def __init__( initialization :param httplib: A HTTP client to use :param services: A list of service definitions - :param jwks_uri: A jwks_uri :param httpc_params: HTTP request arguments :return: Client instance """ @@ -66,7 +64,6 @@ def __init__( keyjar=keyjar, config=config, services=services, - jwks_uri=jwks_uri, httpc_params=httpc_params, client_type=client_type ) diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 1ee5f4a6..1e821e8a 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -188,8 +188,6 @@ def init_client(self, issuer): if self.jwks_uri: _cnf['jwks_uri'] = self.jwks_uri - elif self.jwks: - _cnf['jwks'] = self.jwks try: client = self.client_cls( diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 8d94bc11..0142cf5f 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -18,7 +18,6 @@ from idpyoidc.client.configure import Configuration from idpyoidc.client.work_condition.oauth2 import WorkCondition as OAUTH2_Specs from idpyoidc.client.work_condition.oidc import WorkCondition as OIDC_Specs -from idpyoidc.context import OidcContext from idpyoidc.util import rndstr from .configure import get_configuration from .state_interface import StateInterface @@ -27,6 +26,7 @@ from .work_condition import work_condition_load from .work_condition.transform import preferred_to_registered from .work_condition.transform import supported_to_preferred +from ..impexp import ImpExp logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ } -class ServiceContext(OidcContext): +class ServiceContext(ImpExp): """ This class keeps information that a client needs to be able to talk to a server. Some of this information comes from configuration and some @@ -85,30 +85,28 @@ class ServiceContext(OidcContext): But information is also picked up during the conversation with a server. """ - parameter = OidcContext.parameter.copy() - parameter.update( - { - "add_on": None, - "allow": None, - "args": None, - "base_url": None, - "behaviour": None, - "client_secret_expires_at": 0, - "clock_skew": None, - "config": None, - "hash_seed": b"", - "httpc_params": None, - "iss_hash": None, - "issuer": None, - "work_condition": WorkCondition, - "provider_info": None, - "requests_dir": None, - "registration_response": None, - "state": StateInterface, - 'usage': None, - "verify_args": None, - } - ) + parameter = { + "add_on": None, + "allow": None, + "args": None, + "base_url": None, + # "behaviour": None, + # "client_secret_expires_at": 0, + "clock_skew": None, + "config": None, + "hash_seed": b"", + "httpc_params": None, + "iss_hash": None, + "issuer": None, + 'keyjar': KeyJar, + "work_condition": WorkCondition, + "provider_info": None, + "requests_dir": None, + "registration_response": None, + "state": StateInterface, + # 'usage': None, + "verify_args": None, + } special_load_dump = { "specs": {"load": work_condition_load, "dump": work_condition_dump}, @@ -122,6 +120,7 @@ def __init__(self, state: Optional[StateInterface] = None, client_type: Optional[str] = 'oauth2', **kwargs): + ImpExp.__init__(self) config = get_configuration(config) self.config = config self.client_get = client_get @@ -133,34 +132,26 @@ def __init__(self, else: raise ValueError(f"Unknown client type: {client_type}") - OidcContext.__init__(self, config, keyjar, entity_id=config.conf.get("client_id", "")) + self.entity_id = config.conf.get("client_id", "") self.state = state or StateInterface() self.kid = {"sig": {}, "enc": {}} - self.base_url = base_url or config.get("base_url") or config.conf.get('base_url', '') - # Below so my IDE won't complain self.allow = config.conf.get('allow', {}) + self.base_url = base_url or config.get("base_url", "") + self.provider_info = config.conf.get("provider_info", {}) + + # Below so my IDE won't complain self.args = {} self.add_on = {} self.iss_hash = "" self.issuer = "" self.httpc_params = {} self.client_secret_expires_at = 0 - self.provider_info = {} - # self.post_logout_redirect_uri = "" - # self.redirect_uris = [] self.registration_response = {} - # self.requests_dir = "" _def_value = copy.deepcopy(DEFAULT_VALUE) - _val = config.conf.get("client_secret") - if _val: - self.keyjar.add_symmetric("", _val) - - self.provider_info = config.conf.get("provider_info", {}) - _issuer = config.get("issuer") if _issuer: self.issuer = _issuer @@ -175,6 +166,15 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) + self.keyjar = self.work_condition.load_conf(config.conf, supports=self.supports(), + keyjar=keyjar) + + _response_types = self.get_preference( + 'response_types_supported', + self.supports().get('response_types_supported', [])) + + self.construct_uris(response_types=_response_types) + def __setitem__(self, key, value): setattr(self, key, value) diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_condition/__init__.py index df7452ef..e88450bf 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_condition/__init__.py @@ -162,7 +162,7 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} - def load_conf(self, configuration, supports): + def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None): for attr, val in configuration.items(): if attr == "preference": for k, v in val.items(): @@ -173,12 +173,14 @@ def load_conf(self, configuration, supports): self.locals(configuration) - for key, val in self.handle_keys(configuration).items(): - if val: + for key, val in self.handle_keys(configuration, keyjar=keyjar).items(): + if key == 'keyjar': + keyjar = val + elif val: self.set_preference(key, val) self.verify_rules() - return self + return keyjar def get(self, key, default=None): if key in self._local: diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py index e949c4ea..c0acafc0 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_condition/oidc.py @@ -1,9 +1,6 @@ import os from typing import Optional -from cryptojwt import KeyJar -from cryptojwt.key_jar import init_key_jar - from idpyoidc.client import work_condition diff --git a/src/idpyoidc/context.py b/src/idpyoidc/context.py index d7af05ec..e0fa61b5 100644 --- a/src/idpyoidc/context.py +++ b/src/idpyoidc/context.py @@ -1,9 +1,6 @@ import copy from urllib.parse import quote_plus -from cryptojwt import KeyJar -from cryptojwt.key_jar import init_key_jar - from idpyoidc.impexp import ImpExp @@ -20,35 +17,8 @@ def add_issuer(conf, issuer): class OidcContext(ImpExp): - parameter = {"keyjar": KeyJar, "issuer": None} + parameter = {"issuer": None} - def __init__(self, config=None, keyjar=None, entity_id=""): + def __init__(self, config=None, entity_id=""): ImpExp.__init__(self) - if config is None: - config = {} - self.keyjar = self._keyjar(keyjar, conf=config, entity_id=entity_id) - - def _keyjar(self, keyjar=None, conf=None, entity_id=""): - if keyjar is None: - if "keys" in conf: - keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - elif "key_conf" in conf and conf["key_conf"]: - keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - else: - _keyjar = KeyJar() - if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") - - if "" in _keyjar and entity_id: - # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) - - _httpc_params = conf.get("httpc_params") - if _httpc_params: - _keyjar.httpc_params = _httpc_params - - return _keyjar - else: - return keyjar + self.entity_id = entity_id or config.get('client_id') diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 826a4289..2d059fc3 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJjNENQb1N0Q1BnY0hhOHVySG1kYk43d2V1bnZteUVKTk0wZ2oyUmxvY01jIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5MjgxODA3LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.BGBFSfVc5TNAIU2_4Z2QE4tfRPt-IVZwV_4zUGQWawmsDJZbo1dc7NZZjsXyI_1gHsG7mn7utylqtS8Q-jStg3ikyn886eON6UUIbf9YX--rkl3D8_es9CLnFL0jbAOX2sl573ujMBG-IlEgTLyIgzOgibLfJau0x-JOlQLP9l-dFqFxaE9LeeiPfm2sa7Y6HLk9BkjW7rC2UsrTLuf5PTYwriF7n9IXcCquZbNcbXdk2xA9Oy3ozvxrpMlxqA45peNTK9d97hE5cfJ6sB0vrouDPJ3x_o36AXbGx82PVL8Ce0ZCRXytLZqEtS21nwJiu_nadvvg8fodt81Gm8IOmQ \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJhbmZGdkFnZ0h5Z3JPSkhCdjctUDllNE5xaFF0MmpIcDlaRXlvU1V6S3VJIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5NTM2NTEwLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.d07N--8b-wxfCWIdGtZaVhRGCTTUpsNhu4OAiQrNHx3PSGbbicyoEzJLWgEjH1oAdD-d63iu8ak-C_47Ve1kMewBZ1MdiN4GsOqxvL2fX0WfuHHR0A1ui5Ag5ciWDdlSE7l7G4d1G7FVFtqRAlEt4Hwe1MsPoFKHgDgYuOrPWi1As2SDsOYnmuySFXdQqSh-wLMsPoXUMGQAzEKLsTC2ZnOtNWawZIOnYO74f8LSYLBxnlopI027AihLIsqOR4rxbVv3fX_okRz9iB9IxTCCvAc3UsSrVXeCdWhEFGK6SdznOCSHR4JftVRV7CGqDezn-U9Uwk71p7ggNltEOfKUEQ \ No newline at end of file diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index 7d6a8baa..97c860b6 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -14,7 +14,7 @@ "base_url": "https://example.com/cli", "key_conf": {"key_defs": KEYDEFS}, "issuer": "https://op.example.com", - "metadata": { + "preference": { "response_types": ["code"] } } @@ -31,15 +31,7 @@ def test_init(self): def test_filename_from_webname(self): _filename = self.service_context.filename_from_webname("https://example.com/cli/jwks.json") - assert _filename == "jwks.json" - - # def test_create_callback_uris(self): - # base_url = "https://example.com/cli" - # hex = "0123456789" - # self.service_context.work_condition.construct_redirect_uris(base_url, hex, []) - # _uris = self.service_context.work_condition.get_metadata_claim("redirect_uris") - # assert len(_uris) == 1 - # assert _uris == [f"https://example.com/cli/authz_cb/{hex}"] + assert _filename == 'jwks.json' def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") @@ -83,7 +75,7 @@ def test_get_enc_alg_enc(self): assert _alg_enc == {"alg": ["RSA1_5", "A128KW"], "enc": ["A128CBC+HS256", "A128GCM"]} def test_get(self): - assert self.service_context.get("base_url") == MINI_CONFIG["base_url"] + assert self.service_context.base_url == MINI_CONFIG["base_url"] def test_set(self): self.service_context.set_preference("client_id", "number5") diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 9de8d198..8c9bb5e7 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -7,14 +7,14 @@ CLIENT_CONFIG = { "base_url": "https://example.com/cli", "client_secret": "a longesh password", + "client_id": "client_id", + "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": ISS, "application_name": "rphandler", "preference": { "application_type": "web", "contacts": "support@example.com", "response_types": ["code"], - "client_id": "client_id", - "redirect_uris": ["https://example.com/cli/authz_cb"], 'request_parameter': "request_uri", "request_object_signing_alg_values_supported": ["ES256"], "scope": ["openid", "profile", "email", "address", "phone"], @@ -78,8 +78,6 @@ def test_create_client(): 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', - 'jwks', - 'keyjar', 'post_logout_redirect_uris', 'redirect_uris', 'request_object_encryption_alg_values_supported', @@ -108,7 +106,7 @@ def test_create_client(): _conf_args = list(_context.collect_usage().keys()) assert _conf_args - assert len(_conf_args) == 23 + assert len(_conf_args) == 21 rr = set(RegistrationRequest.c_param.keys()) # The ones that are not defined d = rr.difference(set(_conf_args)) @@ -120,6 +118,7 @@ def test_create_client(): 'id_token_encrypted_response_alg', 'id_token_encrypted_response_enc', 'initiate_login_uri', + 'jwks', 'jwks_uri', 'logo_uri', 'policy_uri', @@ -156,5 +155,6 @@ def test_create_client_keyjar(): def test_create_client_jwks_uri(): client_config = CLIENT_CONFIG.copy() - client = Entity(config=client_config, jwks_uri="https://rp.example.com/jwks_uri.json") + client_config['jwks_uri'] = "https://rp.example.com/jwks_uri.json" + client = Entity(config=client_config) assert client.get_service_context().get_preference("jwks_uri") diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 5a2e1767..f4112df0 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -51,11 +51,19 @@ def test_1(self): def test_use(self): use = self.service_context.map_preferred_to_registered() - assert set(use.keys()) == {'client_id', 'redirect_uris', 'response_types', - 'grant_types', 'application_type', 'jwks', 'subject_type', - 'id_token_signed_response_alg', 'default_max_age', - 'request_object_signing_alg', 'callback_uris', 'scope', - 'response_modes_supported'} + assert set(use.keys()) == {'application_type', + 'callback_uris', + 'client_id', + 'default_max_age', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'response_modes_supported', + 'response_types', + 'scope', + 'subject_type'} def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index f0dea234..a6c3fae0 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -202,6 +202,6 @@ def test_keyjar(self): } _context = self.client.client_get("service_context") - assert len(_context.keyjar) == 1 # one issuer + assert len(_context.keyjar) == 2 # one issuer assert len(_context.keyjar[""]) == 2 assert len(_context.keyjar.get("sig")) == 2 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 6e0df3d3..3c2ff767 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -473,59 +473,6 @@ def test_id_token_nonce_match(self): self.service.update_service_context(resp, key="state2") -SERVICES = { - "discovery": { - "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", - "kwargs": {} - }, - "registration": { - "class": "idpyoidc.client.oidc.registration.Registration", - "kwargs": {} - }, - "authorization": { - "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": { - "metadata": { - "request_object_signing_alg": "ES256" - }, - "usage": { - "request_uri": True - } - } - }, - "accesstoken": { - "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": { - "conf": { - "token_endpoint_auth_method": "private_key_jwt", - "token_endpoint_auth_signing_alg": "ES256" - } - } - }, - "userinfo": { - "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": { - "conf": { - "userinfo_signed_response_alg": "ES256" - }, - } - }, - "end_session": { - "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": { - "conf": { - "post_logout_redirect_uri": "https://rp.example.com/post", - "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True - }, - "usage": { - "backchannel_logout": True - } - } - } -} - - class TestProviderInfo(object): @pytest.fixture(autouse=True) @@ -799,7 +746,6 @@ def test_post_parse(self): # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] - del use_copy['keyjar'] del use_copy['callback_uris'] assert use_copy == {'application_type': 'web', @@ -860,7 +806,6 @@ def test_post_parse_2(self): # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] - del use_copy['keyjar'] del use_copy['callback_uris'] assert use_copy == { @@ -985,6 +930,7 @@ def test_config_with_required_request_uri(): 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} + def test_config_logout_uri(): client_config = { "client_id": "client_id", @@ -995,10 +941,18 @@ def test_config_logout_uri(): "request_uris": ["https://example.com/cli/requests"], "base_url": "https://example.com/cli/", "preference": { - "request_parameter": "request_uri" + "request_parameter": "request_uri", + "request_object_signing_alg": "ES256", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "userinfo_signed_response_alg": "ES256", + "post_logout_redirect_uri": "https://rp.example.com/post", + "backchannel_logout_uri": "https://rp.example.com/back", + "backchannel_logout_session_required": True, + 'backchannel_logout_supported': True } } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=SERVICES, + entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') _context = entity.client_get("service_context") _context.issuer = "https://example.com" @@ -1322,9 +1276,10 @@ def test_jwks_uri_arg(): client_config = { "client_secret": "a longesh password", "issuer": ISS, - "metadata": { - "client_id": "client_id", - "redirect_uris": ["https://example.com/cli/authz_cb"], + "client_id": "client_id", + "redirect_uris": ["https://example.com/cli/authz_cb"], + "jwks_uri": "https://example.com/jwks/jwks.json", + "preference": { "id_token_signed_response_alg": "RS384", "userinfo_signed_response_alg": "RS384", }, @@ -1332,7 +1287,6 @@ def test_jwks_uri_arg(): entity = Entity( keyjar=make_keyjar(), config=client_config, - jwks_uri="https://example.com/jwks/jwks.json", services=DEFAULT_OIDC_SERVICES, client_type='oidc' ) diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index a6bfa41f..a2569858 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -251,9 +251,8 @@ def test_init_client(self): } _pref = [k for k, v in _context.prefers().items() if v] - assert set(_pref) == {'jwks', 'client_id', 'client_secret', 'redirect_uris', - 'response_types_supported', 'callback_uris', 'scopes_supported', - 'keyjar'} + assert set(_pref) == {'client_id', 'client_secret', 'redirect_uris', + 'response_types_supported', 'callback_uris', 'scopes_supported'} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -262,7 +261,7 @@ def test_init_client(self): # secret. 2 because one is marked for encryption and the other signing # usage. - assert list(_context.keyjar.owners()) == ["", _github_id] + assert set(_context.keyjar.owners()) == {"", 'eeeeeeeee', _github_id} keys = _context.keyjar.get_issuer_keys("") assert len(keys) == 2 @@ -306,7 +305,7 @@ def test_do_client_setup(self): _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert list(_context.keyjar.owners()) == ["", _github_id] + assert set(_context.keyjar.owners()) == {"", "eeeeeeeee", _github_id} keys = _context.keyjar.get_issuer_keys("") assert len(keys) == 2 diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index dcf6b22c..b858e6f1 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -12,7 +12,7 @@ BASE_URL = "https://example.com/rp" -METADATA = { +PREFERENCE = { "application_type": "web", "contacts": ["ops@example.com"], "response_types": [ @@ -24,17 +24,13 @@ "code token", ], "token_endpoint_auth_method": "client_secret_basic", -} - -USAGE = { "scope": ["openid", "profile", "email", "address", "phone"], "verify_args": {"allow_sign_alg_none": True}, } CLIENT_CONFIG = { "": { - "metadata": METADATA, - "usage": USAGE, + "preference": PREFERENCE, "redirect_uris": None, "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, @@ -246,7 +242,7 @@ def test_do_client_setup(self): _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert list(_context.keyjar.owners()) == ["", _github_id] + assert set(_context.keyjar.owners()) == {"", 'eeeeeeeee', _github_id} keys = _context.keyjar.get_issuer_keys("") assert len(keys) == 2 From 2a60d20d87429d73effd297f5fe601ec247fc70b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 29 Nov 2022 20:18:57 +0100 Subject: [PATCH 22/76] Replaced StateInterface with Current a much simpler state manager. --- .../actor/client/oidc/registration.py | 4 +- src/idpyoidc/client/client_auth.py | 13 +- src/idpyoidc/client/entity.py | 18 +- src/idpyoidc/client/oauth2/__init__.py | 7 +- src/idpyoidc/client/oauth2/access_token.py | 18 +- .../oauth2/add_on/identity_assurance.py | 5 +- src/idpyoidc/client/oauth2/add_on/pkce.py | 6 +- src/idpyoidc/client/oauth2/authorization.py | 9 +- .../client_credentials/cc_access_token.py | 6 +- .../cc_refresh_access_token.py | 16 +- .../client/oauth2/refresh_access_token.py | 15 +- src/idpyoidc/client/oauth2/token_exchange.py | 17 +- src/idpyoidc/client/oidc/access_token.py | 16 +- src/idpyoidc/client/oidc/authorization.py | 31 ++- src/idpyoidc/client/oidc/check_id.py | 14 +- src/idpyoidc/client/oidc/check_session.py | 13 +- src/idpyoidc/client/oidc/end_session.py | 18 +- .../client/oidc/provider_info_discovery.py | 12 +- .../client/oidc/refresh_access_token.py | 4 +- src/idpyoidc/client/oidc/registration.py | 4 +- src/idpyoidc/client/oidc/userinfo.py | 38 +--- src/idpyoidc/client/provider/github.py | 4 +- src/idpyoidc/client/provider/linkedin.py | 4 +- src/idpyoidc/client/rp_handler.py | 207 ++++++++---------- src/idpyoidc/client/service.py | 2 +- src/idpyoidc/client/service_context.py | 68 +++--- src/idpyoidc/client/state_interface.py | 10 +- .../__init__.py | 6 +- .../oauth2.py | 6 +- .../oidc.py | 14 +- .../transform.py | 0 tests/request123456.jwt | 2 +- tests/test_08_transform.py | 20 +- tests/test_09_work_condition.py | 42 ++-- tests/test_client_00_current.py | 92 ++++++++ tests/test_client_01_service_context.py | 10 +- tests/test_client_04_service.py | 4 +- tests/test_client_06_client_authn.py | 41 ++-- tests/test_client_12_client_auth.py | 12 +- .../test_client_14_service_context_impexp.py | 8 +- tests/test_client_20_oauth2.py | 24 +- tests/test_client_21_oidc_service.py | 62 +++--- tests/test_client_22_oidc.py | 28 +-- tests/test_client_23_pkce.py | 8 +- tests/test_client_24_oic_utils.py | 2 +- tests/test_client_25_cc_oauth2_service.py | 23 +- tests/test_client_27_conversation.py | 38 ++-- tests/test_client_28_rp_handler_oidc.py | 119 +++++----- tests/test_client_30_rph_defaults.py | 4 +- tests/test_client_31_oauth2_persistent.py | 20 +- tests/test_client_32_oidc_persistent.py | 34 +-- tests/test_client_41_rp_handler_persistent.py | 35 +-- tests/test_client_51_identity_assurance.py | 12 +- tests/test_client_55_token_exchange.py | 13 +- tests/test_tandem_10_token_exchange.py | 4 +- ...t_00_state.py => xtest_client_00_state.py} | 0 56 files changed, 630 insertions(+), 632 deletions(-) rename src/idpyoidc/client/{work_condition => work_environment}/__init__.py (98%) rename src/idpyoidc/client/{work_condition => work_environment}/oauth2.py (78%) rename src/idpyoidc/client/{work_condition => work_environment}/oidc.py (82%) rename src/idpyoidc/client/{work_condition => work_environment}/transform.py (100%) create mode 100644 tests/test_client_00_current.py rename tests/{test_client_00_state.py => xtest_client_00_state.py} (100%) diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index 8e196559..2c98c411 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -162,10 +162,10 @@ def add_client_preference(self, request_args=None, **kwargs): continue try: - request_args[prop] = _context.work_condition.get_usage(prop) + request_args[prop] = _context.work_environment.get_usage(prop) except KeyError: try: - request_args[prop] = _context.work_condition.get_preference[prop] + request_args[prop] = _context.work_environment.get_preference[prop] except KeyError: pass return request_args, {} diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 58f73ede..a1d8f672 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -270,15 +270,9 @@ def find_token(request, token_type, service, **kwargs): try: return kwargs["access_token"] except KeyError: - # I should pick the latest acquired token, this should be the right - # order for that. + # Get the latest acquired access token. _state = kwargs.get("state", kwargs.get("key")) - _arg = service.client_get("service_context").state.multiple_extend_request_args( - {}, - _state, - ["access_token"], - ["auth_response", "token_response", "refresh_token_response"], - ) + _arg = service.client_get("service_context").cstate.get_set(_state, claim=[token_type]) return _arg.get("access_token") @@ -410,7 +404,8 @@ def get_signing_key_from_keyjar(algorithm, service_context): Pick signing key based on signing algorithm to be used :param algorithm: Signing algorithm - :param service_context: A :py:class:`idpyoidc.client.service_context.ServiceContext` instance + :param service_context: A :py:class:`idpyoidc.client.service_context.ServiceContext` + instance :return: A key """ return service_context.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index f9b3a7f6..90a95de3 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -145,11 +145,11 @@ def get_entity(self): return self def get_client_id(self): - _val = self._service_context.work_condition.get_usage('client_id') + _val = self._service_context.work_environment.get_usage('client_id') if _val: return _val else: - return self._service_context.work_condition.get_preference('client_id') + return self._service_context.work_environment.get_preference('client_id') def setup_client_authn_methods(self, config): self._service_context.client_authn_method = client_auth_setup( @@ -157,18 +157,18 @@ def setup_client_authn_methods(self, config): ) def backward_compatibility(self, config): - _work_condition = self._service_context.work_condition + _work_environment = self._service_context.work_environment _uris = config.get("redirect_uris") if _uris: - _work_condition.set_preference("redirect_uris", _uris) + _work_environment.set_preference("redirect_uris", _uris) _dir = config.conf.get("requests_dir") if _dir: - _work_condition.set_preference('requests_dir', _dir) + _work_environment.set_preference('requests_dir', _dir) _pref = config.get("client_preferences", {}) for key, val in _pref.items(): - _work_condition.set_preference(key, val) + _work_environment.set_preference(key, val) auth_request_args = config.conf.get("request_args", {}) if auth_request_args: @@ -182,12 +182,12 @@ def config_args(self): "preference": service.supports(), } res[""] = { - "preference": self._service_context.work_condition.supports, + "preference": self._service_context.work_environment.supports, } return res def prefers(self): - return self._service_context.work_condition.prefers() + return self._service_context.work_environment.prefers() def use(self): - return self._service_context.work_condition.get_use() + return self._service_context.work_environment.get_use() diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 4eb60c0f..c464d4a2 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -194,12 +194,7 @@ def service_request( if "error" in response: pass else: - try: - kwargs["key"] = kwargs["state"] - except KeyError: - pass - - service.update_service_context(response, **kwargs) + service.update_service_context(response, key=kwargs.get('state'), **kwargs) return response def parse_request_response(self, service, reqresp, response_body_type="", state="", **kwargs): diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index a04a864b..e8f6a076 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -1,10 +1,11 @@ """Implements the service that talks to the Access Token endpoint.""" import logging +from typing import Optional from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service -from idpyoidc.client.work_condition import get_client_authn_methods -from idpyoidc.client.work_condition import get_signing_algs +from idpyoidc.client.work_environment import get_client_authn_methods +from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.time_util import time_sans_frac @@ -35,10 +36,11 @@ def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) - def update_service_context(self, resp, key="", **kwargs): + def update_service_context(self, resp, key: Optional[str] = '', **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "token_response", key) + if key: + self.client_get("service_context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): """ @@ -51,13 +53,7 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): parameters = list(self.msg_type.c_param.keys()) _context = self.client_get("service_context") - _args = _context.state.extend_request_args( - {}, oauth2.AuthorizationRequest, "auth_request", _state, parameters - ) - - _args = _context.state.extend_request_args( - _args, oauth2.AuthorizationResponse, "auth_response", _state, parameters - ) + _args = _context.cstate.get_set(_state, claim=parameters) if "grant_type" not in _args: _args["grant_type"] = "authorization_code" diff --git a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py index 6b8e535a..9944ffb9 100644 --- a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py +++ b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py @@ -35,9 +35,10 @@ def format_response(format, response, verified_response): def identity_assurance_process(response, service_context, state): - auth_request = service_context.state.get_item(AuthorizationRequest, "auth_request", state) + auth_request = service_context.cstate.get_set(state, + message=AuthorizationRequest) claims_request = auth_request.get("claims") - if "userinfo" in claims_request: + if claims_request and "userinfo" in claims_request: vc = VerifiedClaims(**response["verified_claims"]) # find the claims request in the authorization request diff --git a/src/idpyoidc/client/oauth2/add_on/pkce.py b/src/idpyoidc/client/oauth2/add_on/pkce.py index 06877b27..d45f411a 100644 --- a/src/idpyoidc/client/oauth2/add_on/pkce.py +++ b/src/idpyoidc/client/oauth2/add_on/pkce.py @@ -50,7 +50,7 @@ def add_code_challenge(request_args, service, **kwargs): raise Unsupported("PKCE Transformation method:{}".format(_method)) _item = Message(code_verifier=code_verifier, code_challenge_method=_method) - _context.state.store_item(_item, "pkce", request_args["state"]) + _context.cstate.update(request_args["state"], _item) request_args.update({"code_challenge": code_challenge, "code_challenge_method": _method}) return request_args, {} @@ -69,8 +69,8 @@ def add_code_verifier(request_args, service, **kwargs): _state = request_args.get("state") if _state is None: _state = kwargs.get("state") - _item = service.client_get("service_context").state.get_item(Message, "pkce", _state) - request_args.update({"code_verifier": _item["code_verifier"]}) + _item = service.client_get("service_context").cstate.get_set(_state, claim=['code_verifier']) + request_args.update(_item) return request_args diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 75809075..221a68d4 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -49,12 +49,12 @@ def __init__(self, client_get, conf=None): def update_service_context(self, resp, key="", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "auth_response", key) + self.client_get("service_context").cstate.update(key, resp) def store_auth_request(self, request_args=None, **kwargs): """Store the authorization request in the state DB.""" _key = get_state_parameter(request_args, kwargs) - self.client_get("service_context").state.store_item(request_args, "auth_request", _key) + self.client_get("service_context").cstate.update(_key, request_args) return request_args def gather_request_args(self, **kwargs): @@ -87,9 +87,8 @@ def post_parse_response(self, response, **kwargs): pass else: if _key: - item = self.client_get("service_context").state.get_item( - oauth2.AuthorizationRequest, "auth_request", _key - ) + item = self.client_get("service_context").cstate.get_set( + _key, message=oauth2.AuthorizationRequest) try: response["scope"] = item["scope"] except KeyError: diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py index 896c0897..1837e180 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py @@ -1,3 +1,5 @@ +from typing import Optional + from idpyoidc.client.service import Service from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -19,7 +21,7 @@ class CCAccessToken(Service): def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) - def update_service_context(self, resp, key="cc", **kwargs): + def update_service_context(self, resp, key: Optional[str] = '', **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "token_response", key) + self.client_get("service_context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py index 50fa4931..838fdb9f 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py @@ -1,3 +1,5 @@ +from typing import Optional + from idpyoidc.client.service import Service from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -22,14 +24,8 @@ def __init__(self, client_get, conf=None): def cc_pre_construct(self, request_args=None, **kwargs): _state_id = kwargs.get("state", "cc") parameters = ["refresh_token"] - _state_interface = self.client_get("service_context").state - _args = _state_interface.extend_request_args( - {}, oauth2.AccessTokenResponse, "token_response", _state_id, parameters - ) - - _args = _state_interface.extend_request_args( - _args, oauth2.AccessTokenResponse, "refresh_token_response", _state_id, parameters - ) + _current = self.client_get("service_context").cstate + _args = _current.get_set(_state_id, claim=parameters) if request_args is None: request_args = _args @@ -48,7 +44,7 @@ def cc_post_construct(self, request_args, **kwargs): return request_args - def update_service_context(self, resp, key="cc", **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "token_response", key) + self.client_get("service_context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/refresh_access_token.py b/src/idpyoidc/client/oauth2/refresh_access_token.py index 5feb496d..6ba8f986 100644 --- a/src/idpyoidc/client/oauth2/refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/refresh_access_token.py @@ -1,5 +1,6 @@ """The service that talks to the OAuth2 refresh access token endpoint.""" import logging +from typing import Optional from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service @@ -26,24 +27,18 @@ def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) - def update_service_context(self, resp, key="", **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "token_response", key) + self.client_get("service_context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, **kwargs): """Preconstructor of request arguments""" _state = get_state_parameter(request_args, kwargs) parameters = list(self.msg_type.c_param.keys()) - _si = self.client_get("service_context").state - _args = _si.extend_request_args( - {}, oauth2.AccessTokenResponse, "token_response", _state, parameters - ) - - _args = _si.extend_request_args( - _args, oauth2.AccessTokenResponse, "refresh_token_response", _state, parameters - ) + _current = self.client_get("service_context").cstate + _args = _current.get_set(_state, claim=parameters) if request_args is None: request_args = _args diff --git a/src/idpyoidc/client/oauth2/token_exchange.py b/src/idpyoidc/client/oauth2/token_exchange.py index ed6390af..0a31d743 100644 --- a/src/idpyoidc/client/oauth2/token_exchange.py +++ b/src/idpyoidc/client/oauth2/token_exchange.py @@ -1,5 +1,6 @@ """Implements the service that can exchange one token for another.""" import logging +from typing import Optional from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service @@ -30,10 +31,10 @@ def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) - def update_service_context(self, resp, key="", **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").state.store_item(resp, "token_response", key) + self.client_get("service_context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): """ @@ -53,17 +54,9 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): parameters = {'access_token', 'scope'} - _state = self.client_get("service_context").state + _current = self.client_get("service_context").cstate - _args = _state.extend_request_args( - {}, oauth2.AuthorizationResponse, "auth_response", _key, parameters - ) - _args = _state.extend_request_args( - _args, oauth2.AccessTokenResponse, "token_response", _key, parameters - ) - _args = _state.extend_request_args( - _args, oauth2.AccessTokenResponse, "refresh_token_response", _key, parameters - ) + _args = _current.get_set(_key, claim=parameters) request_args["subject_token"] = _args["access_token"] request_args["subject_token_type"] = 'urn:ietf:params:oauth:token-type:access_token' diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 16f17f3c..342bb4db 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -5,8 +5,8 @@ from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG -from idpyoidc.client.work_condition import get_client_authn_methods -from idpyoidc.client.work_condition import get_signing_algs +from idpyoidc.client.work_environment import get_client_authn_methods +from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name @@ -62,32 +62,32 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.work_condition.get_usage("verify_args") + _verify_args = _context.work_environment.get_usage("verify_args") if _verify_args: if _verify_args: kwargs.update(_verify_args) return kwargs - def update_service_context(self, resp, key="", **kwargs): - _state_interface = self.client_get("service_context").state + def update_service_context(self, resp, key: Optional[str] ="", **kwargs): + _cstate = self.client_get("service_context").cstate try: _idt = resp[verified_claim_name("id_token")] except KeyError: pass else: try: - if _state_interface.get_state_by_nonce(_idt["nonce"]) != key: + if _cstate.get_base_key(_idt["nonce"]) != key: raise ParameterError('Someone has messed with "nonce"') except KeyError: raise ValueError("Invalid nonce value") - _state_interface.store_sub2state(_idt["sub"], key) + _cstate.bind_key(_idt["sub"], key) if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - _state_interface.store_item(resp, "token_response", key) + _cstate.update(key, resp) def get_authn_method(self): _context = self.client_get("service_context") diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 271c679d..d14345e1 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -3,7 +3,7 @@ from typing import Optional from typing import Union -from idpyoidc.client import work_condition +from idpyoidc.client import work_environment from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG @@ -32,9 +32,9 @@ class Authorization(authorization.Authorization): error_msg = oidc.ResponseMessage _supports = { - "request_object_signing_alg_values_supported": work_condition.get_signing_algs, - "request_object_encryption_alg_values_supported": work_condition.get_encryption_algs, - "request_object_encryption_enc_values_supported": work_condition.get_encryption_encs, + "request_object_signing_alg_values_supported": work_environment.get_signing_algs, + "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, + "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', 'code id_token', 'code idtoken token'], 'request_parameter_supported': None, @@ -67,16 +67,17 @@ def __init__(self, client_get, conf=None): self.default_request_args['scope'] = ['openid'] def set_state(self, request_args, **kwargs): + _context = self.client_get("service_context") try: _state = kwargs["state"] except KeyError: try: _state = request_args["state"] except KeyError: - _state = "" + _state = _context.cstate.create_key() - _context = self.client_get("service_context") - request_args["state"] = _context.state.create_state(_context.issuer, _state) + request_args["state"] = _state + _context.cstate.set(_state, {'iss': _context.issuer}) return request_args, {} def update_service_context(self, resp, key="", **kwargs): @@ -84,13 +85,11 @@ def update_service_context(self, resp, key="", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - _context.state.store_item(resp.to_json(), "auth_response", key) + _context.cstate.update(key, resp) def get_request_from_response(self, response): _context = self.client_get("service_context") - return _context.state.get_item( - oauth2.AuthorizationRequest, "auth_request", response["state"] - ) + return _context.cstate.get_set(response["state"], message=oauth2.AuthorizationRequest) def post_parse_response(self, response, **kwargs): response = authorization.Authorization.post_parse_response(self, response, **kwargs) @@ -99,8 +98,8 @@ def post_parse_response(self, response, **kwargs): if _idt: # If there is a verified ID Token then we have to do nonce # verification. - _request = self.get_request_from_response(response) - _req_nonce = _request.get("nonce") + _req_nonce = self.client_get("service_context").cstate.get_set( + response["state"], claim=['nonce']).get('nonce') if _req_nonce: _id_token_nonce = _idt.get("nonce") if not _id_token_nonce: @@ -182,7 +181,7 @@ def get_request_object_signing_alg(self, **kwargs): if not alg: _context = self.client_get("service_context") try: - alg = _context.work_condition.get_usage("request_object_signing_alg") + alg = _context.work_environment.get_usage("request_object_signing_alg") except KeyError: # Use default alg = "RS256" return alg @@ -272,13 +271,13 @@ def oidc_post_construct(self, req, **kwargs): if "openid" in req["scope"]: _response_type = req["response_type"][0] if "id_token" in _response_type or "code" in _response_type: - _context.state.store_nonce2state(req["nonce"], req["state"]) + _context.cstate.bind_key(req["nonce"], req["state"]) if "offline_access" in req["scope"]: if "prompt" not in req: req["prompt"] = "consent" - _context.state.store_item(req, "auth_request", req["state"]) + _context.cstate.update(req["state"], req) # Overrides what's in the configuration _request_param = kwargs.get("request_param") diff --git a/src/idpyoidc/client/oidc/check_id.py b/src/idpyoidc/client/oidc/check_id.py index 7cdab89d..6c3973cd 100644 --- a/src/idpyoidc/client/oidc/check_id.py +++ b/src/idpyoidc/client/oidc/check_id.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from idpyoidc.client.service import Service from idpyoidc.message.oauth2 import Message @@ -22,11 +23,14 @@ def __init__(self, client_get, conf=None): Service.__init__(self, client_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] - def oidc_pre_construct(self, request_args=None, **kwargs): - request_args = self.client_get("service_context").state.multiple_extend_request_args( - request_args, + def oidc_pre_construct(self, request_args: Optional[dict]=None, **kwargs): + _args = self.client_get("service_context").cstate.get_set( kwargs["state"], - ["id_token"], - ["auth_response", "token_response", "refresh_token_response"], + claim=["id_token"] ) + if request_args: + request_args.update() + else: + request_args = _args + return request_args, {} diff --git a/src/idpyoidc/client/oidc/check_session.py b/src/idpyoidc/client/oidc/check_session.py index 692e30dd..525744fb 100644 --- a/src/idpyoidc/client/oidc/check_session.py +++ b/src/idpyoidc/client/oidc/check_session.py @@ -23,10 +23,11 @@ def __init__(self, client_get, conf=None): self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args=None, **kwargs): - request_args = self.client_get("service_context").state.multiple_extend_request_args( - request_args, - kwargs["state"], - ["id_token"], - ["auth_response", "token_response", "refresh_token_response"], - ) + _args = self.client_get("service_context").cstate.get_set(kwargs["state"], + claim=["id_token"]) + if request_args: + request_args.update(_args) + else: + request_args = _args + return request_args, {} diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index eacb1861..47f967a7 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -52,20 +52,14 @@ def get_id_token_hint(self, request_args=None, **kwargs): :param kwargs: :return: """ - request_args = self.client_get("service_context").state.multiple_extend_request_args( - request_args, - kwargs["state"], - ["id_token"], - ["auth_response", "token_response", "refresh_token_response"], - orig=True, - ) + + _args = self.client_get('service_context').cstate.get_set(kwargs["state"], + claim=['id_token']) try: - request_args["id_token_hint"] = request_args["id_token"] + request_args["id_token_hint"] = _args["id_token"] except KeyError: pass - else: - del request_args["id_token"] return request_args, {} @@ -85,8 +79,6 @@ def add_state(self, request_args=None, **kwargs): request_args["state"] = rndstr(32) # As a side effect bind logout state to session state - self.client_get("service_context").state.store_logout_state2state( - request_args["state"], kwargs["state"] - ) + self.client_get("service_context").cstate.bind_key(request_args["state"], kwargs["state"]) return request_args, {} diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index ac21d3fd..50723c24 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -2,7 +2,7 @@ from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.oauth2 import server_metadata -from idpyoidc.client.work_condition.transform import supported_to_preferred +from idpyoidc.client.work_environment.transform import supported_to_preferred from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -25,18 +25,18 @@ def add_redirect_uris(request_args, service=None, **kwargs): :param kwargs: Possible extra keyword arguments :return: A possibly augmented set of request arguments. """ - _work_condition = service.client_get("service_context").work_condition + _work_environment = service.client_get("service_context").work_environment if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. - _callback = _work_condition.get_preference('callback') + _callback = _work_environment.get_preference('callback') if _callback: # Filter out local additions. _uris = [v for k, v in _callback.items() if not k.startswith("__")] request_args["redirect_uris"] = _uris else: - request_args["redirect_uris"] = _work_condition.get_preference( - "redirect_uris", _work_condition.supports.get('redirect_uris')) + request_args["redirect_uris"] = _work_environment.get_preference( + "redirect_uris", _work_environment.supports.get('redirect_uris')) return request_args, {} @@ -52,7 +52,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): def __init__(self, client_get, conf=None): server_metadata.ServerMetadata.__init__(self, client_get, conf=conf) - def update_service_context(self, resp, **kwargs): + def update_service_context(self, resp, key, **kwargs): _context = self.client_get("service_context") self._update_service_context(resp) # set endpoints and import keys _context.map_supported_to_preferred(resp) diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index b6e0ef71..85274d93 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -8,8 +8,8 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): error_msg = oidc.ResponseMessage def get_authn_method(self): - _work_condition = self.client_get("service_context").work_condition + _work_environment = self.client_get("service_context").work_environment try: - return _work_condition.get_usage("token_endpoint_auth_method") + return _work_environment.get_usage("token_endpoint_auth_method") except KeyError: return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 0d2ead7e..1ddeacd4 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -2,7 +2,7 @@ from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service -from idpyoidc.client.work_condition.transform import create_registration_request +from idpyoidc.client.work_environment.transform import create_registration_request from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -95,7 +95,7 @@ def gather_request_args(self, **kwargs): @return: """ _context = self.client_get("service_context") - req_args = create_registration_request(_context.work_condition.prefer, _context.supports()) + req_args = create_registration_request(_context.work_environment.prefer, _context.supports()) if "request_args" in self.conf: req_args.update(self.conf["request_args"]) diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 436bafd0..034a71af 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -2,11 +2,12 @@ from typing import Optional from typing import Union +from idpyoidc import verified_claim_name from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service -from idpyoidc.client.work_condition import get_encryption_algs -from idpyoidc.client.work_condition import get_encryption_encs -from idpyoidc.client.work_condition import get_signing_algs +from idpyoidc.client.work_environment import get_encryption_algs +from idpyoidc.client.work_environment import get_encryption_encs +from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.exception import MissingSigningKey from idpyoidc.message import Message from idpyoidc.message import oidc @@ -59,27 +60,20 @@ def oidc_pre_construct(self, request_args=None, **kwargs): if "access_token" in request_args: pass else: - request_args = self.client_get("service_context").state.multiple_extend_request_args( - request_args, + request_args = self.client_get("service_context").cstate.get_set( kwargs["state"], - ["access_token"], - ["auth_response", "token_response", "refresh_token_response"], + claim=["access_token"] ) return request_args, {} def post_parse_response(self, response, **kwargs): _context = self.client_get("service_context") - _state_interface = _context.state - _args = _state_interface.multiple_extend_request_args( - {}, - kwargs["state"], - ["id_token"], - ["auth_response", "token_response", "refresh_token_response"], - ) + _current = _context.cstate + _args = _current.get_set(kwargs["state"], claim=[verified_claim_name("id_token")]) try: - _sub = _args["id_token"]["sub"] + _sub = _args[verified_claim_name("id_token")]["sub"] except KeyError: logger.warning("Can not verify value on sub") else: @@ -108,22 +102,12 @@ def post_parse_response(self, response, **kwargs): for key in claims: response[key] = aggregated_claims[key] - # elif "endpoint" in spec: - # _info = { - # "headers": self.get_authn_header( - # {}, - # self.default_authn_method, - # authn_endpoint=self.endpoint_name, - # key=kwargs["state"], - # ), - # "url": spec["endpoint"], - # } # Extension point for meth in self.post_parse_process: - response = meth(response, _state_interface, kwargs["state"]) + response = meth(response, _current, kwargs["state"]) - _state_interface.store_item(response, "user_info", kwargs["state"]) + _current.update(kwargs["state"], response) return response def gather_verify_arguments( diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index 0e0e2fa5..ed4ef970 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -1,7 +1,7 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo -from idpyoidc.client.work_condition import get_client_authn_methods -from idpyoidc.client.work_condition import get_signing_algs +from idpyoidc.client.work_environment import get_client_authn_methods +from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_STRING from idpyoidc.message import Message diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 8ddede1d..a9ad7931 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -1,7 +1,7 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo -from idpyoidc.client.work_condition import get_client_authn_methods -from idpyoidc.client.work_condition import get_signing_algs +from idpyoidc.client.work_environment import get_client_authn_methods +from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.message import SINGLE_OPTIONAL_JSON from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 1e821e8a..0398261b 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -19,7 +19,6 @@ from idpyoidc.exception import NotForMe from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message -from idpyoidc.message.oidc import AccessTokenResponse from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import AuthorizationResponse from idpyoidc.message.oidc import Claims @@ -29,7 +28,6 @@ from idpyoidc.time_util import utc_time_sans_frac from idpyoidc.util import add_path from idpyoidc.util import rndstr - from . import oidc from .oauth2 import Client from .oauth2 import dynamic_provider_info_discovery @@ -40,19 +38,19 @@ class RPHandler(object): def __init__( - self, - base_url: Optional[str] = "", - client_configs=None, - services=None, - keyjar=None, - hash_seed="", - verify_ssl=True, - client_cls=None, - state_db=None, - http_lib=None, - httpc_params=None, - config=None, - **kwargs, + self, + base_url: Optional[str] = "", + client_configs=None, + services=None, + keyjar=None, + hash_seed="", + verify_ssl=True, + client_cls=None, + state_db=None, + http_lib=None, + httpc_params=None, + config=None, + **kwargs, ): self.base_url = base_url _jwks_path = kwargs.get("jwks_path") @@ -128,13 +126,10 @@ def state2issuer(self, state): :return: An Issuer ID """ for _rp in self.issuer2rp.values(): - try: - _iss = _rp.client_get("service_context").state.get_iss(state) - except KeyError: - continue - else: - if _iss: - return _iss + _iss = _rp.client_get("service_context").cstate.get_set( + state, claim=['iss']).get('iss') + if _iss: + return _iss return None def pick_config(self, issuer): @@ -151,7 +146,7 @@ def get_session_information(self, key, client=None): """ This is the second of the methods users of this class should know about. It will return the complete session information as an - :py:class:`idpyoidc.client.state_interface.State` instance. + :py:class:`idpyoidc.client.current.Current` instance. :param key: The session key (state) :return: A State instance @@ -159,7 +154,7 @@ def get_session_information(self, key, client=None): if not client: client = self.get_client_from_session_key(key) - return client.client_get("service_context").state.get_state(key) + return client.client_get("service_context").cstate.get(key) def init_client(self, issuer): """ @@ -214,10 +209,10 @@ def init_client(self, issuer): return client def do_provider_info( - self, - client: Optional[Client] = None, - state: Optional[str] = "", - behaviour_args: Optional[dict] = None, + self, + client: Optional[Client] = None, + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, ) -> str: """ Either get the provider info from configuration or through dynamic @@ -279,12 +274,12 @@ def do_provider_info( return _context.get("issuer") def do_client_registration( - self, - client=None, - iss_id: Optional[str] = "", - state: Optional[str] = "", - request_args: Optional[dict] = None, - behaviour_args: Optional[dict] = None, + self, + client=None, + iss_id: Optional[str] = "", + state: Optional[str] = "", + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, ): """ Prepare for and do client registration if configured to do so @@ -340,10 +335,10 @@ def do_webfinger(self, user: str) -> Client: return temporary_client def client_setup( - self, - iss_id: Optional[str] = "", - user: Optional[str] = "", - behaviour_args: Optional[dict] = None, + self, + iss_id: Optional[str] = "", + user: Optional[str] = "", + behaviour_args: Optional[dict] = None, ) -> Client: """ First if no issuer ID is given then the identifier for the user is @@ -394,16 +389,17 @@ def client_setup( def _get_response_type(self, context, req_args: Optional[dict] = None): if req_args: - return req_args.get("response_type", context.work_condition.get_usage("response_types")[0]) + return req_args.get("response_type", + context.work_environment.get_usage("response_types")[0]) else: - return context.work_condition.get_usage("response_types")[0] + return context.work_environment.get_usage("response_types")[0] def init_authorization( - self, - client: Optional[Client] = None, - state: Optional[str] = "", - req_args: Optional[dict] = None, - behaviour_args: Optional[dict] = None, + self, + client: Optional[Client] = None, + state: Optional[str] = "", + req_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, ) -> dict: """ Constructs the URL that will redirect the user to the authorization @@ -433,7 +429,7 @@ def init_authorization( "redirect_uri": pick_redirect_uri( _context, _entity, request_args=req_args, response_type=_response_type ), - "scope": _context.work_condition.get_usage("scope"), + "scope": _context.work_environment.get_usage("scope"), "response_type": _response_type, "nonce": _nonce, } @@ -448,9 +444,11 @@ def init_authorization( request_args.update(req_args) # Need a new state for a new authorization request - _state = _context.state.create_state(_context.get("issuer")) + _current = _context.cstate + _state = _current.create_key() request_args["state"] = _state - _context.state.store_nonce2state(_nonce, _state) + _current.bind_key(_nonce, _state) + _current.set(_state, {'iss': _context.get("issuer")}) logger.debug("Authorization request args: {}".format(request_args)) @@ -507,7 +505,7 @@ def get_response_type(client): :param client: A Client instance :return: The response_type """ - return client.service_context.work_condition.get_usage("response_types")[0] + return client.service_context.work_environment.get_usage("response_types")[0] @staticmethod def get_client_authn_method(client, endpoint): @@ -545,15 +543,12 @@ def get_tokens(self, state, client: Optional[Client] = None): client = self.get_client_from_session_key(state) _context = client.client_get("service_context") - authorization_response = _context.state.get_item( - AuthorizationResponse, "auth_response", state - ) - authorization_request = _context.state.get_item(AuthorizationRequest, "auth_request", state) + _claims = _context.cstate.get_set(state, claim=['code', 'redirect_uri']) req_args = { - "code": authorization_response["code"], + "code": _claims["code"], "state": state, - "redirect_uri": authorization_request["redirect_uri"], + "redirect_uri": _claims["redirect_uri"], "grant_type": "authorization_code", "client_id": client.get_client_id(), "client_secret": _context.get("client_secret"), @@ -633,12 +628,8 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): client = self.get_client_from_session_key(state) if not access_token: - _arg = client.client_get("service_context").state.multiple_extend_request_args( - {}, - state, - ["access_token"], - ["auth_response", "token_response", "refresh_token_response"], - ) + _arg = client.client_get("service_context").cstate.get_set(state, + claim=["access_token"]) access_token = _arg["access_token"] request_args = {"access_token": access_token} @@ -663,7 +654,7 @@ def userinfo_in_id_token(id_token): return res def finalize_auth( - self, client, issuer: str, response: dict, behaviour_args: Optional[dict] = None + self, client, issuer: str, response: dict, behaviour_args: Optional[dict] = None ): """ Given the response returned to the redirect_uri, parse and verify it. @@ -696,7 +687,8 @@ def finalize_auth( _context = client.client_get("service_context") try: - _iss = _context.state.get_iss(authorization_response["state"]) + _iss = _context.cstate.get_set( + authorization_response["state"], claim=['iss']).get('iss') except KeyError: raise KeyError("Unknown state value") @@ -706,17 +698,14 @@ def finalize_auth( raise ValueError("Impersonator {}".format(issuer)) _srv.update_service_context(authorization_response, key=authorization_response["state"]) - _context.state.store_item( - authorization_response, "auth_response", authorization_response["state"] - ) return authorization_response def get_access_and_id_token( - self, - authorization_response=None, - state: Optional[str] = "", - client: Optional[object] = None, - behaviour_args: Optional[dict] = None, + self, + authorization_response=None, + state: Optional[str] = "", + client: Optional[object] = None, + behaviour_args: Optional[dict] = None, ): """ There are a number of services where access tokens and ID tokens can @@ -739,19 +728,16 @@ def get_access_and_id_token( _context = client.client_get("service_context") - if authorization_response is None: - if state: - authorization_response = _context.state.get_item( - AuthorizationResponse, "auth_response", state - ) - else: - raise ValueError("One of authorization_response or state must be provided") + resp_attr = authorization_response or _context.cstate.get_set(state, + message=AuthorizationResponse) + if resp_attr is None: + raise ValueError("One of authorization_response or state must be provided") if not state: state = authorization_response["state"] - authreq = _context.state.get_item(AuthorizationRequest, "auth_request", state) - _resp_type = set(authreq["response_type"]) + _req_attr = _context.cstate.get_set(state, AuthorizationRequest) + _resp_type = set(_req_attr["response_type"].split(' ')) access_token = None id_token = None @@ -860,12 +846,12 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): except KeyError: pass else: - _context.state.store_sid2state(sid, _state) + _context.cstate.bind_key(sid, _state) if _id_token: - _context.state.store_sub2state(_id_token["sub"], _state) + _context.cstate.bind_key(_id_token["sub"], _state) else: - _context.state.store_sub2state(inforesp["sub"], _state) + _context.cstate.bind_key(inforesp["sub"], _state) return { "userinfo": inforesp, @@ -885,13 +871,9 @@ def has_active_authentication(self, state): client = self.get_client_from_session_key(state) - # Look for Id Token in all the places where it can be - _arg = client.client_get("service_context").state.multiple_extend_request_args( - {}, - state, - ["__verified_id_token"], - ["auth_response", "token_response", "refresh_token_response"], - ) + # Look for an IdToken + _arg = client.client_get("service_context").cstate.get_set(state, + claim=["__verified_id_token"]) if _arg: _now = utc_time_sans_frac() @@ -909,33 +891,21 @@ def get_valid_access_token(self, state): expires. Otherwise raise exception. """ - exp = 0 token = None indefinite = [] now = utc_time_sans_frac() client = self.get_client_from_session_key(state) _context = client.client_get("service_context") - for cls, typ in [ - (AccessTokenResponse, "refresh_token_response"), - (AccessTokenResponse, "token_response"), - (AuthorizationResponse, "auth_response"), - ]: - try: - response = _context.state.get_item(cls, typ, state) - except KeyError: - pass + _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) + if "access_token" in _args: + access_token = _args["access_token"] + _exp = _args.get("__expires_at", 0) + if not _exp: # No expiry date, lives for ever + indefinite.append((access_token, 0)) else: - if "access_token" in response: - access_token = response["access_token"] - try: - _exp = response["__expires_at"] - except KeyError: # No expiry date, lives for ever - indefinite.append((access_token, 0)) - else: - if _exp > now and _exp > exp: # expires sometime in the future - exp = _exp - token = (access_token, _exp) + if _exp > now: # expires sometime in the future + token = (access_token, _exp) if indefinite: return indefinite[0] @@ -946,10 +916,10 @@ def get_valid_access_token(self, state): raise OidcServiceError("No valid access token") def logout( - self, - state: str, - client: Optional[Client] = None, - post_logout_redirect_uri: Optional[str] = "", + self, + state: str, + client: Optional[Client] = None, + post_logout_redirect_uri: Optional[str] = "", ) -> dict: """ Does a RP initiated logout from an OP. After logout the user will be @@ -983,7 +953,8 @@ def logout( return resp def close( - self, state: str, issuer: Optional[str] = "", post_logout_redirect_uri: Optional[str] = "" + self, state: str, issuer: Optional[str] = "", + post_logout_redirect_uri: Optional[str] = "" ) -> dict: logger.debug(20 * "*" + " close " + 20 * "*") @@ -999,7 +970,7 @@ def close( def clear_session(self, state): client = self.get_client_from_session_key(state) - client.client_get("service_context").state.remove_state(state) + client.client_get("service_context").cstate.remove_state(state) def backchannel_logout(client, request="", request_args=None): @@ -1042,9 +1013,9 @@ def backchannel_logout(client, request="", request_args=None): if not sub and not sid: raise MessageException('Neither "sid" nor "sub"') elif sub: - _state = _context.state.get_state_by_sub(sub) + _state = _context.cstate.get_base_key(sub) elif sid: - _state = _context.state.get_state_by_sid(sid) + _state = _context.cstate.get_base_key(sid) else: _state = None diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 7e45f497..8357038f 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -203,7 +203,7 @@ def do_post_construct(self, request_args, **kwargs): return request_args - def update_service_context(self, resp, key="", **kwargs): + def update_service_context(self, resp: Message, key: Optional[str] = '', **kwargs): """ A method run after the response has been parsed and verified. diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 0142cf5f..2fb1dbfd 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -16,16 +16,16 @@ from cryptojwt.utils import as_bytes from idpyoidc.client.configure import Configuration -from idpyoidc.client.work_condition.oauth2 import WorkCondition as OAUTH2_Specs -from idpyoidc.client.work_condition.oidc import WorkCondition as OIDC_Specs +from idpyoidc.client.work_environment.oauth2 import WorkEnvironment as OAUTH2_Specs +from idpyoidc.client.work_environment.oidc import WorkEnvironment as OIDC_Specs from idpyoidc.util import rndstr from .configure import get_configuration -from .state_interface import StateInterface -from .work_condition import WorkCondition -from .work_condition import work_condition_dump -from .work_condition import work_condition_load -from .work_condition.transform import preferred_to_registered -from .work_condition.transform import supported_to_preferred +from .current import Current +from .work_environment import WorkEnvironment +from .work_environment import work_environment_dump +from .work_environment import work_environment_load +from .work_environment.transform import preferred_to_registered +from .work_environment.transform import supported_to_preferred from ..impexp import ImpExp logger = logging.getLogger(__name__) @@ -99,17 +99,17 @@ class ServiceContext(ImpExp): "iss_hash": None, "issuer": None, 'keyjar': KeyJar, - "work_condition": WorkCondition, + "work_environment": WorkEnvironment, "provider_info": None, "requests_dir": None, "registration_response": None, - "state": StateInterface, + "cstate": Current, # 'usage': None, "verify_args": None, } special_load_dump = { - "specs": {"load": work_condition_load, "dump": work_condition_dump}, + "specs": {"load": work_environment_load, "dump": work_environment_dump}, } def __init__(self, @@ -117,7 +117,7 @@ def __init__(self, base_url: Optional[str] = "", keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, - state: Optional[StateInterface] = None, + cstate: Optional[Current] = None, client_type: Optional[str] = 'oauth2', **kwargs): ImpExp.__init__(self) @@ -126,14 +126,14 @@ def __init__(self, self.client_get = client_get if not client_type or client_type == "oidc": - self.work_condition = OIDC_Specs() + self.work_environment = OIDC_Specs() elif client_type == "oauth2": - self.work_condition = OAUTH2_Specs() + self.work_environment = OAUTH2_Specs() else: raise ValueError(f"Unknown client type: {client_type}") self.entity_id = config.conf.get("client_id", "") - self.state = state or StateInterface() + self.cstate = cstate or Current() self.kid = {"sig": {}, "enc": {}} @@ -166,7 +166,7 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - self.keyjar = self.work_condition.load_conf(config.conf, supports=self.supports(), + self.keyjar = self.work_environment.load_conf(config.conf, supports=self.supports(), keyjar=keyjar) _response_types = self.get_preference( @@ -227,9 +227,9 @@ def _get_crypt(self, typ, attr): _item_typ = CLI_REG_MAP.get(typ) _alg = '' if _item_typ: - _alg = self.work_condition.get_usage(_item_typ[attr]) + _alg = self.work_environment.get_usage(_item_typ[attr]) if not _alg: - _alg = self.work_condition.get_preference(_item_typ[attr]) + _alg = self.work_environment.get_preference(_item_typ[attr]) if not _alg: _item_typ = PROVIDER_INFO_MAP.get(typ) @@ -266,10 +266,10 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.work_condition.get_usage("client_id") + return self.work_environment.get_usage("client_id") def collect_usage(self): - return self.work_condition.use + return self.work_environment.use def supports(self): res = {} @@ -277,23 +277,23 @@ def supports(self): services = self.client_get('services') for service in services.values(): res.update(service.supports()) - res.update(self.work_condition.supports()) + res.update(self.work_environment.supports()) return res def prefers(self): - return self.work_condition.prefer + return self.work_environment.prefer def get_preference(self, claim, default=None): - return self.work_condition.get_preference(claim, default=default) + return self.work_environment.get_preference(claim, default=default) def set_preference(self, key, value): - self.work_condition.set_preference(key, value) + self.work_environment.set_preference(key, value) def get_usage(self, claim, default: Optional[str] = None): - return self.work_condition.get_usage(claim, default) + return self.work_environment.get_usage(claim, default) def set_usage(self, claim, value): - return self.work_condition.set_usage(claim, value) + return self.work_environment.set_usage(claim, value) def _callback_per_service(self): _cb = {} @@ -329,7 +329,7 @@ def construct_uris(self, response_types: Optional[list] = None): self.set_preference('redirect_uris', list(_redirect_uris)) def prefer_or_support(self, claim): - if claim in self.work_condition.prefer: + if claim in self.work_environment.prefer: return 'prefer' else: for service in self.client_get('services').values(): @@ -337,20 +337,20 @@ def prefer_or_support(self, claim): if _res: return _res - if claim in self.work_condition.supported(claim): + if claim in self.work_environment.supported(claim): return 'support' return None def map_supported_to_preferred(self, info: Optional[dict] = None): - self.work_condition.prefer = supported_to_preferred(self.supports(), - self.work_condition.prefer, + self.work_environment.prefer = supported_to_preferred(self.supports(), + self.work_environment.prefer, base_url=self.base_url, info=info) - return self.work_condition.prefer + return self.work_environment.prefer def map_preferred_to_registered(self, registration_response: Optional[dict] = None): - self.work_condition.use = preferred_to_registered( - self.work_condition.prefer, + self.work_environment.use = preferred_to_registered( + self.work_environment.prefer, supported=self.supports(), registration_response=registration_response) - return self.work_condition.use + return self.work_environment.use diff --git a/src/idpyoidc/client/state_interface.py b/src/idpyoidc/client/state_interface.py index ab752456..94818d77 100644 --- a/src/idpyoidc/client/state_interface.py +++ b/src/idpyoidc/client/state_interface.py @@ -88,7 +88,7 @@ def get_state(self, key): Get the state connected to a given key. :param key: Key into the state database - :return: A :py:class:´idpyoidc.client.state_interface.State` instance + :return: A :py:class:´idpyoidc.client.current.Current` instance """ _data = self._db.get(key) if not _data: @@ -155,7 +155,7 @@ def extend_request_args(self, args, item_cls, item_type, key, parameters, orig=F :param item_cls: The :py:class:`idpyoidc.message.Message` subclass that describes the item :param item_type: The type of item, this is one of the parameter - names in the :py:class:`idpyoidc.client.state_interface.State` class. + names in the :py:class:`idpyoidc.client.current.Current` class. :param key: The key to the information in the database :param parameters: A list of parameters who's values this method will return. @@ -302,7 +302,7 @@ def get_state_by_logout_state(self, logout_state): """ Find the state value by providing the logout state value. Will raise an exception if the logout state value is absent from the - state data base. + state database. :param logout_state: The logout state value :return: The state value @@ -324,7 +324,7 @@ def get_state_by_sid(self, sid): """ Find the state value by providing the logout state value. Will raise an exception if the logout state value is absent from the - state data base. + state database. :param sid: The session ID value :return: The state value @@ -346,7 +346,7 @@ def get_state_by_sub(self, sub): """ Find the state value by providing the subject id value. Will raise an exception if the subject id value is absent from the - state data base. + state database. :param sub: The Subject ID value :return: The state value diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/work_environment/__init__.py similarity index 98% rename from src/idpyoidc/client/work_condition/__init__.py rename to src/idpyoidc/client/work_environment/__init__.py index e88450bf..8c31f144 100644 --- a/src/idpyoidc/client/work_condition/__init__.py +++ b/src/idpyoidc/client/work_environment/__init__.py @@ -16,18 +16,18 @@ from idpyoidc.util import qualified_name -def work_condition_dump(info, exclude_attributes): +def work_environment_dump(info, exclude_attributes): return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} -def work_condition_load(item: dict, **kwargs): +def work_environment_load(item: dict, **kwargs): _class_name = list(item.keys())[0] # there is only one _cls = importer(_class_name) _cls = _cls().load(item[_class_name]) return _cls -class WorkCondition(ImpExp): +class WorkEnvironment(ImpExp): parameter = { "prefer": None, "use": None, diff --git a/src/idpyoidc/client/work_condition/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py similarity index 78% rename from src/idpyoidc/client/work_condition/oauth2.py rename to src/idpyoidc/client/work_environment/oauth2.py index 8b0861b9..40293cf2 100644 --- a/src/idpyoidc/client/work_condition/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,9 +1,9 @@ from typing import Optional -from idpyoidc.client import work_condition +from idpyoidc.client import work_environment -class WorkCondition(work_condition.WorkCondition): +class WorkEnvironment(work_environment.WorkEnvironment): _supports = { "redirect_uris": None, "grant_types": ["authorization_code", "implicit", "refresh_token"], @@ -31,4 +31,4 @@ class WorkCondition(work_condition.WorkCondition): def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): - work_condition.WorkCondition.__init__(self, prefer=prefer, callback_path=callback_path) + work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_environment/oidc.py similarity index 82% rename from src/idpyoidc/client/work_condition/oidc.py rename to src/idpyoidc/client/work_environment/oidc.py index c0acafc0..e5adbcb1 100644 --- a/src/idpyoidc/client/work_condition/oidc.py +++ b/src/idpyoidc/client/work_environment/oidc.py @@ -1,20 +1,20 @@ import os from typing import Optional -from idpyoidc.client import work_condition +from idpyoidc.client import work_environment -class WorkCondition(work_condition.WorkCondition): - parameter = work_condition.WorkCondition.parameter.copy() +class WorkEnvironment(work_environment.WorkEnvironment): + parameter = work_environment.WorkEnvironment.parameter.copy() parameter.update({ "requests_dir": None }) _supports = { "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "id_token_signing_alg_values_supported": work_condition.get_signing_algs, - "id_token_encryption_alg_values_supported": work_condition.get_encryption_algs, - "id_token_encryption_enc_values_supported": work_condition.get_encryption_encs, + "id_token_signing_alg_values_supported": work_environment.get_signing_algs, + "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, + "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, "acr_values_supported": None, "subject_types_supported": ["public", "pairwise", "ephemeral"], "application_type": "web", @@ -43,7 +43,7 @@ def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None ): - work_condition.WorkCondition.__init__(self, prefer=prefer, callback_path=callback_path) + work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) def verify_rules(self): if self.get_preference("request_parameter_supported") and self.get_preference( diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_environment/transform.py similarity index 100% rename from src/idpyoidc/client/work_condition/transform.py rename to src/idpyoidc/client/work_environment/transform.py diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 2d059fc3..5fead1d0 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJhbmZGdkFnZ0h5Z3JPSkhCdjctUDllNE5xaFF0MmpIcDlaRXlvU1V6S3VJIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5NTM2NTEwLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.d07N--8b-wxfCWIdGtZaVhRGCTTUpsNhu4OAiQrNHx3PSGbbicyoEzJLWgEjH1oAdD-d63iu8ak-C_47Ve1kMewBZ1MdiN4GsOqxvL2fX0WfuHHR0A1ui5Ag5ciWDdlSE7l7G4d1G7FVFtqRAlEt4Hwe1MsPoFKHgDgYuOrPWi1As2SDsOYnmuySFXdQqSh-wLMsPoXUMGQAzEKLsTC2ZnOtNWawZIOnYO74f8LSYLBxnlopI027AihLIsqOR4rxbVv3fX_okRz9iB9IxTCCvAc3UsSrVXeCdWhEFGK6SdznOCSHR4JftVRV7CGqDezn-U9Uwk71p7ggNltEOfKUEQ \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJvQzlvLUIwZ2dJZzRVeFgxQ0ZEN0hOVFpOTnplZUlSWjh2azZzZTZMR213IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5NzM0MDAxLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.adpuPLhsRTs5__3vLjHMPn1nlFYXHq6imhQ6ZAyF5XAwp0TCTNd7ZP6gFtiR-iGOsLFJCbrDyCuGC8opB3c3ETHXVgVnMoE_KzwwwFw20PL4zkq_B0lYgp1PEi9nb7q9a0qVQujb-hkRdq3B8ntaRxdaGnofQ1sP_DtbiqZDyNDbWnT3Wv7H-rLdotStTcZ7KGslQarZJHxN_0m1Mr7Ucon0VL267RciGf0x5pNQ_wQjj9T5uzVOKZMV7gfis_Gr1PlqDPwBDm_tH_10K49mn3exmx1dUlCN-Taw67yTR9Puqo3w6rQxImWusq00LxfHPtl4POZS7RKCQWEU2lQ_wA \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 2e19ebf6..ac92245c 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -3,11 +3,11 @@ from cryptojwt.utils import importer import pytest -from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC -from idpyoidc.client.work_condition.transform import REGISTER2PREFERRED -from idpyoidc.client.work_condition.transform import create_registration_request -from idpyoidc.client.work_condition.transform import preferred_to_registered -from idpyoidc.client.work_condition.transform import supported_to_preferred +from idpyoidc.client.work_environment.oidc import WorkEnvironment as WorkEnvironmentOIDC +from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED +from idpyoidc.client.work_environment.transform import create_registration_request +from idpyoidc.client.work_environment.transform import preferred_to_registered +from idpyoidc.client.work_environment.transform import supported_to_preferred from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest @@ -15,7 +15,7 @@ class TestTransform: @pytest.fixture(autouse=True) def setup(self): - supported = WorkConditionOIDC._supports.copy() + supported = WorkEnvironmentOIDC._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -260,8 +260,8 @@ class TestTransform2: @pytest.fixture(autouse=True) def setup(self): - self.work_condition = WorkConditionOIDC() - supported = self.work_condition._supports.copy() + self.work_environment = WorkEnvironmentOIDC() + supported = self.work_environment._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -295,7 +295,7 @@ def setup(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(preference, self.supported) + self.work_environment.load_conf(preference, self.supported) def test_registration_response(self): OP_BASEURL = 'https://example.com' @@ -323,7 +323,7 @@ def test_registration_response(self): } pref = supported_to_preferred(supported=self.supported, - preference=self.work_condition.prefer, + preference=self.work_environment.prefer, base_url='https://example.com', info=provider_info_response) diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 96b5f2ae..82df52f6 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -3,10 +3,10 @@ from cryptojwt.utils import importer import pytest as pytest -from idpyoidc.client.work_condition.oidc import WorkCondition as WorkConditionOIDC -from idpyoidc.client.work_condition.transform import create_registration_request -from idpyoidc.client.work_condition.transform import preferred_to_registered -from idpyoidc.client.work_condition.transform import supported_to_preferred +from idpyoidc.client.work_environment.oidc import WorkEnvironment as WorkEnvironmentOIDC +from idpyoidc.client.work_environment.transform import create_registration_request +from idpyoidc.client.work_environment.transform import preferred_to_registered +from idpyoidc.client.work_environment.transform import supported_to_preferred KEYSPEC = [ {"type": "RSA", "use": ["sig"]}, @@ -14,12 +14,12 @@ ] -class TestWorkCondition: +class TestWorkEnvironment: @pytest.fixture(autouse=True) def setup(self): - self.work_condition = WorkConditionOIDC() - supported = self.work_condition._supports.copy() + self.work_environment = WorkEnvironmentOIDC() + supported = self.work_environment._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -57,9 +57,9 @@ def test_load_conf(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(client_conf, self.supported) - assert self.work_condition.get_preference('jwks') is None - assert self.work_condition.get_preference('jwks_uri') is None + self.work_environment.load_conf(client_conf, self.supported) + assert self.work_environment.get_preference('jwks') is None + assert self.work_environment.get_preference('jwks_uri') is None def test_load_jwks(self): # Symmetric and asymmetric keys published as JWKS @@ -76,9 +76,9 @@ def test_load_jwks(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(client_conf, self.supported) - assert self.work_condition.get_preference('jwks') is not None - assert self.work_condition.get_preference('jwks_uri') is None + self.work_environment.load_conf(client_conf, self.supported) + assert self.work_environment.get_preference('jwks') is not None + assert self.work_environment.get_preference('jwks_uri') is None def test_load_jwks_uri1(self): # Symmetric and asymmetric keys published through a jwks_uri @@ -93,9 +93,9 @@ def test_load_jwks_uri1(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(client_conf, self.supported) - assert self.work_condition.get_preference('jwks') is None - assert self.work_condition.get_preference( + self.work_environment.load_conf(client_conf, self.supported) + assert self.work_environment.get_preference('jwks') is None + assert self.work_environment.get_preference( 'jwks_uri') == f"{client_conf['base_url']}{client_conf['keys']['uri_path']}" def test_load_jwks_uri2(self): @@ -112,9 +112,9 @@ def test_load_jwks_uri2(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(client_conf, self.supported) - assert self.work_condition.get_preference('jwks') is None - assert self.work_condition.get_preference('jwks_uri') == client_conf['jwks_uri'] + self.work_environment.load_conf(client_conf, self.supported) + assert self.work_environment.get_preference('jwks') is None + assert self.work_environment.get_preference('jwks_uri') == client_conf['jwks_uri'] def test_registration_response(self): client_conf = { @@ -130,7 +130,7 @@ def test_registration_response(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_condition.load_conf(client_conf, self.supported) + self.work_environment.load_conf(client_conf, self.supported) OP_BASEURL = 'https://example.com' provider_info_response = { @@ -157,7 +157,7 @@ def test_registration_response(self): } pref = supported_to_preferred(supported=self.supported, - preference=self.work_condition.prefer, + preference=self.work_environment.prefer, base_url='https://example.com', info=provider_info_response) diff --git a/tests/test_client_00_current.py b/tests/test_client_00_current.py new file mode 100644 index 00000000..b701d6dc --- /dev/null +++ b/tests/test_client_00_current.py @@ -0,0 +1,92 @@ +import pytest + +from idpyoidc.client.current import Current +from idpyoidc.message import Message + +ISSUER = "https://example.com" + + +class TestCurrent: + @pytest.fixture(autouse=True) + def test_setup(self): + self.current = Current() + + def test_create_key_no_key(self): + state_key = self.current.create_key() + self.current.set(state_key, {'iss': ISSUER}) + _iss = self.current.get(state_key)['iss'] + assert _iss == ISSUER + _item = self.current.get_set(state_key, claim=['iss']) + assert _item['iss'] == ISSUER + + def test_store_and_retrieve_state_item(self): + state_key = self.current.create_key() + item = Message(foo="bar", issuer=ISSUER) + self.current.set(state_key, item) + _state = self.current.get(state_key) + assert set(_state.keys()) == {"issuer", "foo"} + _item = self.current.get_set(state_key, Message) + assert set(_item.keys()) == set() # since Message has no attribute definitions + + def test_nonce(self): + state_key = self.current.create_key() + self.current.bind_key("nonce", state_key) + _state_key = self.current.get_base_key("nonce") + assert _state_key == state_key + + def test_other_id(self): + state_key = self.current.create_key() + self.current.bind_key("subject_id", state_key) + self.current.bind_key("nonce", state_key) + self.current.bind_key("session_id", state_key) + self.current.bind_key("logout_id", state_key) + + _state_key = self.current.get_base_key("nonce") + assert _state_key == state_key + _state_key = self.current.get_base_key("subject_id") + assert _state_key == state_key + _state_key = self.current.get_base_key("session_id") + assert _state_key == state_key + _state_key = self.current.get_base_key("logout_id") + assert _state_key == state_key + + def test_remove(self): + state_key = self.current.create_state(iss='foo') + self.current.bind_key("subject_id", state_key) + self.current.bind_key("nonce", state_key) + self.current.bind_key("session_id", state_key) + self.current.bind_key("logout_id", state_key) + + _state_key = self.current.get_base_key("nonce") + assert _state_key == state_key + _state_key = self.current.get_base_key("subject_id") + assert _state_key == state_key + _state_key = self.current.get_base_key("session_id") + assert _state_key == state_key + _state_key = self.current.get_base_key("logout_id") + assert _state_key == state_key + + self.current.remove_state(state_key) + with pytest.raises(KeyError): + self.current.get_base_key(state_key) + with pytest.raises(KeyError): + self.current.get_base_key("subject_id") + with pytest.raises(KeyError): + self.current.get_base_key("nonce") + with pytest.raises(KeyError): + self.current.get_base_key("session_id") + with pytest.raises(KeyError): + self.current.get_base_key("logout_id") + + def test_extend_request_args(self): + state_key = self.current.create_key() + + item = Message(foo="bar") + self.current.set(state_key, item) + + args = self.current.get_set(state_key, claim=["foo"]) + assert args == {"foo": "bar"} + + # unknown attribute + args = self.current.get_set(state_key, claim=["fox"]) + assert args == {} diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index 97c860b6..a57e9c82 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -37,11 +37,11 @@ def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") assert _alg is None - self.service_context.work_condition.set_preference("id_token_signed_response_alg", "RS384") + self.service_context.work_environment.set_preference("id_token_signed_response_alg", "RS384") _alg = self.service_context.get_sign_alg("id_token") assert _alg == "RS384" - self.service_context.work_condition.prefer = {} + self.service_context.work_environment.prefer = {} self.service_context.provider_info["id_token_signing_alg_values_supported"] = [ "RS256", "ES256", @@ -53,15 +53,15 @@ def test_get_enc_alg_enc(self): _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": None, "enc": None} - self.service_context.work_condition.set_preference("userinfo_encrypted_response_alg", + self.service_context.work_environment.set_preference("userinfo_encrypted_response_alg", "RSA1_5") - self.service_context.work_condition.set_preference("userinfo_encrypted_response_enc", + self.service_context.work_environment.set_preference("userinfo_encrypted_response_enc", "A128CBC+HS256") _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": "RSA1_5", "enc": "A128CBC+HS256"} - self.service_context.work_condition.prefer = {} + self.service_context.work_environment.prefer = {} self.service_context.provider_info["userinfo_encryption_alg_values_supported"] = [ "RSA1_5", "A128KW", diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index f4112df0..951e94dd 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -184,9 +184,9 @@ def test_response(self): _info = self.service.get_request_parameters(request_args=req_args) assert set(_info.keys()) == {"url", "method", "request"} msg = Message().from_urlencoded(self.service.get_urlinfo(_info["url"])) - self.service.client_get("service_context").state.store_item(msg, "request", _state) + self.service.client_get("service_context").cstate.set(_state, msg) resp1 = AuthorizationResponse(code="auth_grant", state=_state) response = self.service.parse_response(resp1.to_urlencoded(), "urlencoded", state=_state) self.service.update_service_context(response, key=_state) - assert self.service.client_get("service_context").state.get_state(_state) + assert self.service.client_get("service_context").cstate.get(_state) diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 67804c0b..1544d1ee 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -1,27 +1,27 @@ import base64 import os -import pytest from cryptojwt.exception import MissingKey -from cryptojwt.jws.jws import factory from cryptojwt.jws.jws import JWS +from cryptojwt.jws.jws import factory from cryptojwt.jwt import JWT from cryptojwt.key_bundle import KeyBundle -from cryptojwt.key_jar import init_key_jar from cryptojwt.key_jar import KeyJar +from cryptojwt.key_jar import init_key_jar +import pytest -from idpyoidc.client.client_auth import assertion_jwt from idpyoidc.client.client_auth import AuthnFailure -from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import BearerBody from idpyoidc.client.client_auth import BearerHeader from idpyoidc.client.client_auth import ClientSecretBasic from idpyoidc.client.client_auth import ClientSecretJWT from idpyoidc.client.client_auth import ClientSecretPost from idpyoidc.client.client_auth import PrivateKeyJWT +from idpyoidc.client.client_auth import assertion_jwt +from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity -from idpyoidc.client.work_condition import WorkCondition +from idpyoidc.client.work_environment import WorkEnvironment from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenRequest @@ -168,11 +168,12 @@ def test_construct_with_resource_request(self, entity): def test_construct_with_token(self, entity): _service = entity.client_get("service", "") srv_cntx = _service.client_get("service_context") - _state = srv_cntx.state.create_state("Issuer") + _state = srv_cntx.cstate.create_key() + srv_cntx.cstate.set(_state, {'iss': "Issuer"}) req = AuthorizationRequest( state=_state, response_type="code", redirect_uri="https://example.com", scope=["openid"] ) - srv_cntx.state.store_item(req, "auth_request", _state) + srv_cntx.cstate.update(_state, req) # Add a state and bind a code to it resp1 = AuthorizationResponse(code="auth_grant", state=_state) @@ -185,9 +186,7 @@ def test_construct_with_token(self, entity): ) response = _service.parse_response(resp2.to_urlencoded(), "urlencoded") - _service.client_get("service_context").state.store_item( - response, "token_response", key=_state - ) + _service.client_get("service_context").cstate.update(_state, response) # and finally use the access token, bound to a state, to # construct the authorization header @@ -208,10 +207,11 @@ def test_construct(self, entity): def test_construct_with_state(self, entity): _auth_service = entity.client_get("service", "") _cntx = _auth_service.client_get("service_context") - _key = _cntx.state.create_state(iss="Issuer") + _key = _cntx.cstate.create_key() + _cntx.cstate.set(_key, {'iss': "Issuer"}) resp = AuthorizationResponse(code="code", state=_key) - _cntx.state.store_item(resp, "auth_response", _key) + _cntx.cstate.update(_key, resp) atr = AccessTokenResponse( access_token="2YotnFZFEjr1zCsicMWpAA", @@ -220,7 +220,7 @@ def test_construct_with_state(self, entity): example_parameter="example_value", scope=["inner", "outer"], ) - _cntx.state.store_item(atr, "token_response", _key) + _cntx.cstate.update(_key, atr) request = ResourceRequest() http_args = BearerBody().construct(request, service=_auth_service, key=_key) @@ -231,7 +231,8 @@ def test_construct_with_request(self, entity): authz_service = entity.client_get("service", "") _cntx = authz_service.client_get("service_context") - _key = _cntx.state.create_state(iss="Issuer") + _key = _cntx.cstate.create_key() + _cntx.cstate.set(_key, {'iss': "Issuer"}) resp1 = AuthorizationResponse(code="auth_grant", state=_key) response = authz_service.parse_response(resp1.to_urlencoded(), "urlencoded") authz_service.update_service_context(response, key=_key) @@ -241,9 +242,7 @@ def test_construct_with_request(self, entity): ) _service2 = entity.client_get("service", "") response = _service2.parse_response(resp2.to_urlencoded(), "urlencoded") - _service2.client_get("service_context").state.store_item( - response, "token_response", key=_key - ) + _service2.client_get("service_context").cstate.update(_key, response) request = ResourceRequest() BearerBody().construct(request, service=authz_service, key=_key) @@ -274,7 +273,7 @@ def test_construct(self, entity): def test_modify_1(self, entity): token_service = entity.client_get("service", "") request = token_service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + 'state': "ABCDE"}) csp = ClientSecretPost() http_args = csp.construct(request, service=token_service) assert "client_secret" in request @@ -282,7 +281,7 @@ def test_modify_1(self, entity): def test_modify_2(self, entity): _service = entity.client_get("service", "") request = _service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + 'state': "ABCDE"}) csp = ClientSecretPost() _service.client_get("service_context").set_usage('client_secret', "") # this will fail @@ -464,7 +463,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Use provider information is everything else fails request = AccessTokenRequest() - _service_context.work_condition = WorkCondition() + _service_context.work_environment = WorkEnvironment() _service_context.provider_info["token_endpoint_auth_signing_alg_values_supported"] = [ "ES256", "RS256", diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 42e91286..6d0d1c6f 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -153,11 +153,11 @@ def test_construct_with_resource_request(self, entity): def test_construct_with_token(self, entity): authz_service = entity.client_get("service", "authorization") srv_cntx = authz_service.client_get("service_context") - _state = srv_cntx.state.create_state("Issuer") + _state = srv_cntx.cstate.create_state(iss="Issuer") req = AuthorizationRequest( state=_state, response_type="code", redirect_uri="https://example.com", scope=["openid"] ) - srv_cntx.state.store_item(req, "auth_request", _state) + srv_cntx.cstate.update(_state, req) # Add a state and bind a code to it resp1 = AuthorizationResponse(code="auth_grant", state=_state) @@ -192,10 +192,10 @@ def test_construct(self, entity): def test_construct_with_state(self, entity): _auth_service = entity.client_get("service", "authorization") _cntx = _auth_service.client_get("service_context") - _key = _cntx.state.create_state(iss="Issuer") + _key = _cntx.cstate.create_state(iss="Issuer") resp = AuthorizationResponse(code="code", state=_key) - _cntx.state.store_item(resp, "auth_response", _key) + _cntx.cstate.update(_key, resp) atr = AccessTokenResponse( access_token="2YotnFZFEjr1zCsicMWpAA", @@ -204,7 +204,7 @@ def test_construct_with_state(self, entity): example_parameter="example_value", scope=["inner", "outer"], ) - _cntx.state.store_item(atr, "token_response", _key) + _cntx.cstate.update(_key, atr) request = ResourceRequest() http_args = BearerBody().construct(request, service=_auth_service, key=_key) @@ -215,7 +215,7 @@ def test_construct_with_request(self, entity): authz_service = entity.client_get("service", "authorization") _cntx = authz_service.client_get("service_context") - _key = _cntx.state.create_state(iss="Issuer") + _key = _cntx.cstate.create_state(iss="Issuer") resp1 = AuthorizationResponse(code="auth_grant", state=_key) response = authz_service.parse_response(resp1.to_urlencoded(), "urlencoded") authz_service.update_service_context(response, key=_key) diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index d0f8edbc..106d829c 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -19,7 +19,7 @@ def test_client_info_init(): "requests_dir": "requests", } ci = ServiceContext(config=config, client_type='oidc') - ci.work_condition.load_conf(config, supports=ci.supports()) + ci.work_environment.load_conf(config, supports=ci.supports()) ci.map_supported_to_preferred() ci.map_preferred_to_registered() @@ -109,7 +109,7 @@ def create_client_info_instance(self): self.service_context = ServiceContext(config=config) def test_registration_userinfo_sign_enc_algs(self): - self.service_context.work_condition.use = { + self.service_context.work_environment.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -128,7 +128,7 @@ def test_registration_userinfo_sign_enc_algs(self): assert srvcntx.get_enc_alg_enc("userinfo") == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} def test_registration_request_object_sign_enc_algs(self): - self.service_context.work_condition.use = { + self.service_context.work_environment.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -150,7 +150,7 @@ def test_registration_request_object_sign_enc_algs(self): assert srvcntx.get_sign_alg("request_object") == "RS384" def test_registration_id_token_sign_enc_algs(self): - self.service_context.work_condition.use = { + self.service_context.work_environment.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index a6c3fae0..fca7ba43 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -2,9 +2,9 @@ import sys import time -import pytest from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle +import pytest from idpyoidc.client.configure import RPHConfiguration from idpyoidc.client.exception import OidcServiceError @@ -65,7 +65,7 @@ def test_construct_authorization_request(self): "response_type": ["code"], } - self.client.client_get("service_context").state.create_state("issuer", key="ABCDE") + self.client.client_get("service_context").cstate.set("ABCDE", {"iss": 'issuer'}) msg = self.client.client_get("service", "authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg["client_id"] == "client_1" @@ -75,19 +75,17 @@ def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} _context = self.client.client_get("service_context") - _context.state.create_state("issuer", "ABCDE") + _context.cstate.set("ABCDE", {"issuer": "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="ABCDE" ) - _context.state.store_item(auth_request, "auth_request", "ABCDE") + _context.cstate.update("ABCDE", auth_request) auth_response = AuthorizationResponse(code="access_code") - self.client.client_get("service_context").state.store_item( - auth_response, "auth_response", "ABCDE" - ) + self.client.client_get("service_context").cstate.update("ABCDE", auth_response) msg = self.client.client_get("service", "accesstoken").construct( request_args=req_args, state="ABCDE" @@ -105,21 +103,22 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): _context = self.client.client_get("service_context") - _context.state.create_state("issuer", "ABCDE") + _state = "ABCDE" + _context.cstate.set(_state, {'iss': "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" ) - _context.state.store_item(auth_request, "auth_request", "ABCDE") + _context.cstate.update(_state, auth_request) auth_response = AuthorizationResponse(code="access_code") - _context.state.store_item(auth_response, "auth_response", "ABCDE") + _context.cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - _context.state.store_item(token_response, "token_response", "ABCDE") + _context.cstate.update(_state, token_response) req_args = {} msg = self.client.client_get("service", "refresh_token").construct( @@ -165,6 +164,7 @@ def test_error_response_2(self): BASE_URL = "https://example.com" + class TestClient2(object): @pytest.fixture(autouse=True) def create_client(self): @@ -190,7 +190,7 @@ def create_client(self): }, } rp_conf = RPHConfiguration(conf) - rp_handler = RPHandler(base_url=BASE_URL,config=rp_conf) + rp_handler = RPHandler(base_url=BASE_URL, config=rp_conf) self.client = rp_handler.init_client(issuer="client_1") assert self.client diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 3c2ff767..53a4010c 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -298,7 +298,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.client_get("service_context").work_condition.set_usage("verify_args", { + self.service.client_get("service_context").work_environment.set_usage("verify_args", { "allow_sign_alg_none": allow_sign_alg_none }) resp = AuthorizationResponse(state="state", code="code", id_token=_idt) @@ -407,13 +407,13 @@ def create_request(self): # add some history auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state", response_type="code" - ).to_json() + ) - _stat_interface = entity.client_get("service_context").state - _stat_interface.store_item(auth_request, "auth_request", "state") + _current = entity.client_get("service_context").cstate + _current.update("state", auth_request) - auth_response = AuthorizationResponse(code="access_code").to_json() - _stat_interface.store_item(auth_response, "auth_response", "state") + auth_response = AuthorizationResponse(code="access_code") + _current.update("state", auth_response) def test_construct(self): req_args = {"foo": "bar"} @@ -464,11 +464,11 @@ def test_request_init(self): } def test_id_token_nonce_match(self): - _state_interface = self.service.client_get("service_context").state - _state_interface.store_nonce2state("nonce", "state") + _cstate = self.service.client_get("service_context").cstate + _cstate.bind_key("nonce", "state") resp = AccessTokenResponse() resp[verified_claim_name("id_token")] = {"nonce": "nonce"} - _state_interface.store_nonce2state("nonce2", "state2") + _cstate.bind_key("nonce2", "state2") with pytest.raises(ParameterError): self.service.update_service_context(resp, key="state2") @@ -730,19 +730,19 @@ def test_post_parse(self): "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } _context = self.service.client_get("service_context") - assert _context.work_condition.use == {} + assert _context.work_environment.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) with responses.RequestsMock() as rsps: rsps.add("GET", resp["jwks_uri"], body=iss_jwks, status=200) - self.service.update_service_context(resp) + self.service.update_service_context(resp, '') # static client registration _context.map_preferred_to_registered() - use_copy = self.service.client_get("service_context").work_condition.use.copy() + use_copy = self.service.client_get("service_context").work_environment.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] @@ -790,19 +790,19 @@ def test_post_parse_2(self): "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } _context = self.service.client_get("service_context") - assert _context.work_condition.use == {} + assert _context.work_environment.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) with responses.RequestsMock() as rsps: rsps.add("GET", resp["jwks_uri"], body=iss_jwks, status=200) - self.service.update_service_context(resp) + self.service.update_service_context(resp, '') # static client registration _context.map_preferred_to_registered() - use_copy = self.service.client_get("service_context").work_condition.use.copy() + use_copy = self.service.client_get("service_context").work_environment.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] @@ -882,7 +882,7 @@ def test_construct(self): 'userinfo_signed_response_alg'} def test_config_with_post_logout(self): - self.service.client_get("service_context").work_condition.set_preference( + self.service.client_get("service_context").work_environment.set_preference( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() @@ -930,7 +930,6 @@ def test_config_with_required_request_uri(): 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} - def test_config_logout_uri(): client_config = { "client_id": "client_id", @@ -997,16 +996,16 @@ def create_request(self): entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").work_condition.use = { + entity.client_get("service_context").work_environment.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", } - _state_interface = self.service.client_get("service_context").state + _cstate = self.service.client_get("service_context").cstate # Add history - auth_response = AuthorizationResponse(code="access_code").to_json() - _state_interface.store_item(auth_response, "auth_response", "abcde") + auth_response = AuthorizationResponse(code="access_code") + _cstate.update("abcde", auth_response) idtval = {"nonce": "KUEYfRM2VzKDaaKD", "sub": "diana", "iss": ISS, "aud": "client_id"} idt = create_jws(idtval) @@ -1015,8 +1014,8 @@ def create_request(self): token_response = AccessTokenResponse( access_token="access_token", id_token=idt, __verified_id_token=ver_idt - ).to_json() - _state_interface.store_item(token_response, "token_response", "abcde") + ) + _cstate.update("abcde", token_response) def test_construct(self): _req = self.service.construct(state="abcde") @@ -1143,10 +1142,8 @@ def create_request(self): self.service = entity.client_get("service", "check_session") def test_construct(self): - _state_interface = self.service.client_get("service_context").state - _state_interface.store_item( - json.dumps({"id_token": "a.signed.jwt"}), "token_response", "abcde" - ) + _cstate = self.service.client_get("service_context").cstate + _cstate.update("abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, CheckSessionRequest) assert len(_req) == 1 @@ -1173,10 +1170,8 @@ def create_request(self): self.service = entity.client_get("service", "check_id") def test_construct(self): - _state_interface = self.service.client_get("service_context").state - _state_interface.store_item( - json.dumps({"id_token": "a.signed.jwt"}), "token_response", "abcde" - ) + _cstate = self.service.client_get("service_context").cstate + _cstate.set("abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, CheckIDRequest) assert len(_req) == 1 @@ -1207,9 +1202,8 @@ def create_request(self): self.service = entity.client_get("service", "end_session") def test_construct(self): - self.service.client_get("service_context").state.store_item( - json.dumps({"id_token": "a.signed.jwt"}), "token_response", "abcde" - ) + self.service.client_get("service_context").cstate.update( + "abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, EndSessionRequest) assert len(_req) == 3 diff --git a/tests/test_client_22_oidc.py b/tests/test_client_22_oidc.py index 5e4ca7d7..f6ac3f9b 100755 --- a/tests/test_client_22_oidc.py +++ b/tests/test_client_22_oidc.py @@ -61,7 +61,7 @@ def test_construct_authorization_request(self): "nonce": "nonce", } - self.client.client_get("service_context").state.create_state("issuer", "ABCDE") + self.client.client_get("service_context").cstate.set("ABCDE", {'iss': "issuer"}) msg = self.client.client_get("service", "authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) @@ -71,14 +71,15 @@ def test_construct_accesstoken_request(self): _context = self.client.client_get("service_context") auth_request = AuthorizationRequest(redirect_uri="https://example.com/cli/authz_cb") - _state = _context.state.create_state("issuer") + _state = _context.cstate.create_key() + _context.cstate.set(_state, {'iss': "issuer"}) auth_request["state"] = _state - _context.state.store_item(auth_request, "auth_request", _state) + _context.cstate.update(_state, auth_request) auth_response = AuthorizationResponse(code="access_code") - _context.state.store_item(auth_response, "auth_response", _state) + _context.cstate.update(_state, auth_response) # Bind access code to state req_args = {} @@ -97,19 +98,19 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): _context = self.client.client_get("service_context") - _context.state.create_state("issuer", "ABCDE") + _context.cstate.set("ABCDE", {'iss':"issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" ) - _context.state.store_item(auth_request, "auth_request", "ABCDE") + _context.cstate.update("ABCDE", auth_request) auth_response = AuthorizationResponse(code="access_code") - _context.state.store_item(auth_response, "auth_response", "ABCDE") + _context.cstate.set("ABCDE", auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - _context.state.store_item(token_response, "token_response", "ABCDE") + _context.cstate.update("ABCDE", token_response) req_args = {} msg = self.client.client_get("service", "refresh_token").construct( @@ -125,23 +126,24 @@ def test_construct_refresh_token_request(self): def test_do_userinfo_request_init(self): _context = self.client.client_get("service_context") - _context.state.create_state("issuer", "ABCDE") + _state = _context.cstate.create_key() + _context.cstate.set(_state, {'iss': "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" ) - _context.state.store_item(auth_request, "auth_request", "ABCDE") + _context.cstate.update(_state, auth_request) auth_response = AuthorizationResponse(code="access_code") - _context.state.store_item(auth_response, "auth_response", "ABCDE") + _context.cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - _context.state.store_item(token_response, "token_response", "ABCDE") + _context.cstate.update(_state, token_response) _srv = self.client.client_get("service", "userinfo") _srv.endpoint = "https://example.com/userinfo" - _info = _srv.get_request_parameters(state="ABCDE") + _info = _srv.get_request_parameters(state=_state) assert _info assert _info["headers"] == {"Authorization": "Bearer access"} assert _info["url"] == "https://example.com/userinfo" diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index 0cd4f195..22ca27c9 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -74,7 +74,7 @@ def create_client(self): def test_add_code_challenge_default_values(self): auth_serv = self.entity.client_get("service", "authorization") - _state_key = self.entity.client_get("service_context").state.create_state(iss="Issuer") + _state_key = self.entity.client_get("service_context").cstate.create_state(iss="Issuer") request_args, _ = add_code_challenge({"state": _state_key}, auth_serv) # default values are length:64 method:S256 @@ -86,7 +86,7 @@ def test_add_code_challenge_default_values(self): def test_authorization_and_pkce(self): auth_serv = self.entity.client_get("service", "authorization") - _state = self.entity.client_get("service_context").state.create_state(iss="Issuer") + _state = self.entity.client_get("service_context").cstate.create_state(iss="Issuer") request = auth_serv.construct_request({"state": _state, "response_type": "code"}) assert set(request.keys()) == { @@ -103,9 +103,7 @@ def test_access_token_and_pkce(self): request = authz_service.construct_request({"state": "state", "response_type": "code"}) _state = request["state"] auth_response = AuthorizationResponse(code="access code") - self.entity.client_get("service_context").state.store_item( - auth_response, "auth_response", _state - ) + self.entity.client_get("service_context").cstate.update(_state, auth_response) token_service = self.entity.client_get("service", "accesstoken") request = token_service.construct_request(state=_state) diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index 4be09df0..f903d992 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -27,7 +27,7 @@ def test_request_object_encryption(): "client_secret": "abcdefghijklmnop", } service_context = ServiceContext(keyjar=KEYJAR, config=conf) - _condition = service_context.work_condition + _condition = service_context.work_environment _condition.set_usage("request_object_encryption_alg", "RSA1_5") _condition.set_usage("request_object_encryption_enc", "A128CBC-HS256") diff --git a/tests/test_client_25_cc_oauth2_service.py b/tests/test_client_25_cc_oauth2_service.py index dfc4251f..282b066e 100644 --- a/tests/test_client_25_cc_oauth2_service.py +++ b/tests/test_client_25_cc_oauth2_service.py @@ -63,9 +63,7 @@ def test_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", _key - ) + info = _srv.client_get("service_context").cstate.get(_key) assert "__expires_at" in info def test_refresh_token_get_request(self): @@ -80,8 +78,7 @@ def test_refresh_token_get_request(self): } ) _srv = self.entity.client_get("service", "refresh_token") - _id = rndstr(16) - _info = _srv.get_request_parameters(state_id=_id) + _info = _srv.get_request_parameters(state='') assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" assert _info["body"] == "grant_type=refresh_token" @@ -109,9 +106,7 @@ def test_refresh_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", _key - ) + info = _srv.client_get("service_context").cstate.get(_key) assert "__expires_at" in info # Move from token to refresh token service @@ -130,9 +125,7 @@ def test_refresh_token_parse_response(self): _response = _srv.parse_response(refresh_response.to_json(), sformat="json") _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", _key - ) + info = _srv.client_get("service_context").cstate.get(_key) assert "__expires_at" in info def test_2nd_refresh_token_parse_response(self): @@ -154,9 +147,7 @@ def test_2nd_refresh_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", _key - ) + info = _srv.client_get("service_context").cstate.get(_key) assert "__expires_at" in info # Move from token to refresh token service @@ -175,9 +166,7 @@ def test_2nd_refresh_token_parse_response(self): _response = _srv.parse_response(refresh_response.to_json(), sformat="json") _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", _key - ) + info = _srv.client_get("service_context").cstate.get(_key) assert "__expires_at" in info _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 342e11fd..7c99920b 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -404,7 +404,7 @@ def test_conversation(): resp = provider_info_service.parse_response(provider_info_response) assert isinstance(resp, ProviderConfigurationResponse) - provider_info_service.update_service_context(resp) + provider_info_service.update_service_context(resp, '') _pi = entity.client_get("service_context").provider_info assert _pi["issuer"] == OP_BASEURL @@ -481,7 +481,7 @@ def test_conversation(): NONCE = "UvudLKz287YByZdsY3AJoPAlEXQkJ0dK" auth_service = entity.client_get("service", "authorization") - _state_interface = service_context.state + _cstate = service_context.cstate info = auth_service.get_request_parameters(request_args={"state": STATE, "nonce": NONCE}) @@ -511,7 +511,7 @@ def test_conversation(): _resp = auth_service.parse_response(_authz_rep.to_urlencoded()) auth_service.update_service_context(_resp, key=STATE) - _item = _state_interface.get_item(AuthorizationResponse, "auth_response", STATE) + _item = _cstate.get(STATE) assert _item["code"] == "Z0FBQUFBQmFkdFFjUVpFWE81SHU5N1N4N01" # =================== Access token ==================== @@ -562,18 +562,22 @@ def test_conversation(): token_service.update_service_context(_resp, key=STATE) - _item = _state_interface.get_item(AccessTokenResponse, "token_response", STATE) - - assert set(_item.keys()) == { - "state", - "scope", - "access_token", - "token_type", - "id_token", - "__verified_id_token", - "expires_in", - "__expires_at", - } + _item = _cstate.get(STATE) + + assert set(_item.keys()) == {'__expires_at', + '__verified_id_token', + 'access_token', + 'client_id', + 'code', + 'expires_in', + 'id_token', + 'iss', + 'nonce', + 'redirect_uri', + 'response_type', + 'scope', + 'state', + 'token_type'} assert _item["token_type"] == "Bearer" assert _item["access_token"] == "Z0FBQUFBQmFkdFF" @@ -594,5 +598,5 @@ def test_conversation(): assert isinstance(_resp, OpenIDSchema) assert _resp.to_dict() == {"sub": "1b2fc9341a16ae4e30082965d537"} - _item = _state_interface.get_item(OpenIDSchema, "user_info", STATE) - assert _item.to_dict() == {"sub": "1b2fc9341a16ae4e30082965d537"} + _item = _cstate.get_set(STATE, message=OpenIDSchema) + assert _item == {"sub": "1b2fc9341a16ae4e30082965d537"} diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index a2569858..566430a8 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -373,7 +373,7 @@ def test_get_client_from_session_key(self): res = self.rph.begin(issuer_id="linkedin") cli1 = self.rph.get_client_from_session_key(state=res["state"]) _session = self.rph.get_session_information(res["state"]) - cli2 = self.rph.issuer2rp[_session["iss"]] + cli2 = self.rph.issuer2rp[_session['iss']] assert cli1 == cli2 # redo self.rph.do_provider_info(state=res["state"]) @@ -384,42 +384,47 @@ def test_get_client_from_session_key(self): def test_finalize_auth(self): res = self.rph.begin(issuer_id="linkedin") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] auth_response = AuthorizationResponse(code="access_code", state=res["state"]) - resp = self.rph.finalize_auth(client, _session["iss"], auth_response.to_dict()) + resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} - aresp = client.client_get("service_context").state.get_item( - AuthorizationResponse, "auth_response", res["state"] - ) - assert set(aresp.keys()) == {"state", "code"} + _state = client.client_get("service_context").cstate.get(res["state"]) + assert set(_state.keys()) == {'client_id', + 'code', + 'iss', + 'nonce', + 'redirect_uri', + 'response_type', + 'scope', + 'state'} def test_get_client_authn_method(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] authn_method = self.rph.get_client_authn_method(client, "token_endpoint") assert authn_method == '' res = self.rph.begin(issuer_id="linkedin") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] authn_method = self.rph.get_client_authn_method(client, "token_endpoint") assert authn_method == "client_secret_post" def test_get_tokens(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _github_id = iss_id("github") _context = client.client_get("service_context") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -446,7 +451,7 @@ def test_get_tokens(self): client.client_get("service", "accesstoken").endpoint = _url auth_response = AuthorizationResponse(code="access_code", state=res["state"]) - resp = self.rph.finalize_auth(client, _session["iss"], auth_response.to_dict()) + resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) resp = self.rph.get_tokens(res["state"], client) assert set(resp.keys()) == { @@ -458,27 +463,31 @@ def test_get_tokens(self): "__expires_at", } - atresp = client.client_get("service_context").state.get_item( - AccessTokenResponse, "token_response", res["state"] - ) - assert set(atresp.keys()) == { - "access_token", - "expires_in", - "id_token", - "token_type", - "__verified_id_token", - "__expires_at", - } + _curr = client.client_get("service_context").cstate.get(res["state"]) + assert set(_curr.keys()) == {'__expires_at', + '__verified_id_token', + 'access_token', + 'client_id', + 'code', + 'expires_in', + 'id_token', + 'iss', + 'nonce', + 'redirect_uri', + 'response_type', + 'scope', + 'state', + 'token_type'} def test_access_and_id_token(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -510,7 +519,7 @@ def test_access_and_id_token(self): client.client_get("service", "accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) + auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) resp = self.rph.get_access_and_id_token(auth_response, client=client) assert resp["access_token"] == "accessTok" assert isinstance(resp["id_token"], IdToken) @@ -518,12 +527,12 @@ def test_access_and_id_token(self): def test_access_and_id_token_by_reference(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -555,7 +564,7 @@ def test_access_and_id_token_by_reference(self): client.client_get("service", "accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) - _ = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) + _ = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) resp = self.rph.get_access_and_id_token(state=res["state"]) assert resp["access_token"] == "accessTok" assert isinstance(resp["id_token"], IdToken) @@ -563,12 +572,12 @@ def test_access_and_id_token_by_reference(self): def test_get_user_info(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -600,7 +609,7 @@ def test_get_user_info(self): client.client_get("service", "accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) + auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) token_resp = self.rph.get_access_and_id_token(auth_response, client=client) @@ -621,15 +630,15 @@ def test_get_user_info(self): def test_userinfo_in_id_token(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() idval = { "nonce": _nonce, "sub": "EndUserSubject", - "iss": _iss, + 'iss': _iss, "aud": _aud, "given_name": "Diana", "family_name": "Krall", @@ -654,12 +663,12 @@ def rphandler_setup(self): self.rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] - _iss = _session["iss"] + _nonce = _session["nonce"] + _iss = _session['iss'] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) @@ -693,7 +702,7 @@ def rphandler_setup(self): client.client_get("service", "accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) + auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) token_resp = self.rph.get_access_and_id_token(auth_response, client=client) @@ -713,7 +722,7 @@ def rphandler_setup(self): def test_init_authorization(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] res = self.rph.init_authorization(client, req_args={"scope": ["openid", "email"]}) part = urlsplit(res["url"]) _qp = parse_qs(part.query) @@ -721,7 +730,7 @@ def test_init_authorization(self): def test_refresh_access_token(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _info = {"access_token": "2nd_accessTok", "token_type": "Bearer", "expires_in": 3600} at = AccessTokenResponse(**_info) @@ -741,7 +750,7 @@ def test_refresh_access_token(self): def test_get_user_info(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session["iss"]] + client = self.rph.issuer2rp[_session['iss']] _url = "https://github.com/userinfo" with responses.RequestsMock() as rsps: @@ -823,7 +832,7 @@ def __call__(self, url, method="GET", data=None, headers=None, **kwargs): def construct_access_token_response(nonce, issuer, client_id, key_jar): _aud = client_id - idval = {"nonce": nonce, "sub": "EndUserSubject", "iss": issuer, "aud": _aud} + idval = {"nonce": nonce, "sub": "EndUserSubject", 'iss': issuer, "aud": _aud} idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -894,7 +903,7 @@ def test_finalize(self): # Faking resp = construct_access_token_response( - _session["auth_request"]["nonce"], + _session["nonce"], issuer=self.issuer, client_id=CLIENT_CONFIG["github"]["client_id"], key_jar=GITHUB_KEY, @@ -920,7 +929,7 @@ def test_finalize(self): # do the rest (= get access token and user info) # assume code flow - resp = self.rph.finalize(_session["iss"], auth_response.to_dict()) + resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {"userinfo", "state", "token", "id_token", "session_state"} diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index b2406862..bff069ee 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -35,7 +35,7 @@ def test_init_client(self): _context = client.client_get("service_context") - assert set(_context.work_condition.prefer.keys()) == { + assert set(_context.work_environment.prefer.keys()) == { 'application_type', 'callback_uris', 'id_token_encryption_alg_values_supported', @@ -95,7 +95,7 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.work_condition.use.keys()) == {'application_type', + assert set(_context.work_environment.use.keys()) == {'application_type', 'callback_uris', 'client_id', 'client_secret', diff --git a/tests/test_client_31_oauth2_persistent.py b/tests/test_client_31_oauth2_persistent.py index 8af0a63f..6c6e3dd1 100644 --- a/tests/test_client_31_oauth2_persistent.py +++ b/tests/test_client_31_oauth2_persistent.py @@ -53,13 +53,13 @@ def test_construct_accesstoken_request(self): # Client 1 starts the chain of event client_1 = Client(config=CONF) _context_1 = client_1.client_get("service_context") - _state = _context_1.state.create_state("issuer") + _state = _context_1.cstate.create_state(iss="issuer") auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - _context_1.state.store_item(auth_request, "auth_request", _state) + _context_1.cstate.update(_state, auth_request) # Client 2 carries on client_2 = Client(config=CONF) @@ -69,7 +69,7 @@ def test_construct_accesstoken_request(self): _context2.load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - _context2.state.store_item(auth_response, "auth_response", _state) + _context2.cstate.update(_state, auth_response) msg = client_2.client_get("service", "accesstoken").construct(request_args={}, state=_state) @@ -86,15 +86,13 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): # Client 1 starts the chain event client_1 = Client(config=CONF) - _state = client_1.client_get("service_context").state.create_state("issuer") + _state = client_1.client_get("service_context").cstate.create_state(iss="issuer") auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").state.store_item( - auth_request, "auth_request", _state - ) + client_1.client_get("service_context").cstate.update(_state, auth_request) # Client 2 carries on client_2 = Client(config=CONF) @@ -102,15 +100,11 @@ def test_construct_refresh_token_request(self): client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").state.store_item( - auth_response, "auth_response", _state - ) + client_2.client_get("service_context").cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").state.store_item( - token_response, "token_response", _state - ) + client_2.client_get("service_context").cstate.update(_state, token_response) # Next up is Client 1 _state_dump = client_2.client_get("service_context").dump() diff --git a/tests/test_client_32_oidc_persistent.py b/tests/test_client_32_oidc_persistent.py index 3a639b16..cd8d75fc 100755 --- a/tests/test_client_32_oidc_persistent.py +++ b/tests/test_client_32_oidc_persistent.py @@ -51,13 +51,11 @@ class TestClient(object): def test_construct_accesstoken_request(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").state.create_state(ISSUER) + _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").state.store_item( - auth_request, "auth_request", _state - ) + client_1.client_get("service_context").cstate.update(_state, auth_request) # Client 2 carries on client_2 = RP(config=CONF) @@ -65,9 +63,7 @@ def test_construct_accesstoken_request(self): client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").state.store_item( - auth_response, "auth_response", _state - ) + client_2.client_get("service_context").cstate.update(_state, auth_response) # Bind access code to state req_args = {} @@ -87,15 +83,13 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").state.create_state(ISSUER) + _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").state.store_item( - auth_request, "auth_request", _state - ) + client_1.client_get("service_context").cstate.update(_state,auth_request) # Client 2 carries on client_2 = RP(config=CONF) @@ -103,14 +97,10 @@ def test_construct_refresh_token_request(self): client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").state.store_item( - auth_response, "auth_response", _state - ) + client_2.client_get("service_context").cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").state.store_item( - token_response, "token_response", _state - ) + client_2.client_get("service_context").cstate.update(_state,token_response ) # Back to Client 1 _state_dump = client_2.client_get("service_context").dump() @@ -131,7 +121,7 @@ def test_construct_refresh_token_request(self): def test_do_userinfo_request_init(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").state.create_state(ISSUER) + _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" @@ -143,14 +133,10 @@ def test_do_userinfo_request_init(self): client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").state.store_item( - auth_response, "auth_response", _state - ) + client_2.client_get("service_context").cstate.update(_state,auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").state.store_item( - token_response, "token_response", _state - ) + client_2.client_get("service_context").cstate.update(_state,token_response) # Back to Client 1 _state_dump = client_2.client_get("service_context").dump() diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index b858e6f1..2558938c 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -2,8 +2,8 @@ from urllib.parse import parse_qs from urllib.parse import urlsplit -import responses from cryptojwt.key_jar import init_key_jar +import responses from idpyoidc.client.rp_handler import RPHandler from idpyoidc.message.oidc import AccessTokenResponse @@ -326,10 +326,11 @@ def test_finalize_auth(self): assert set(resp.keys()) == {"state", "code"} aresp = ( client.client_get("service", "authorization") - .client_get("service_context") - .state.get_item(AuthorizationResponse, "auth_response", res["state"]) + .client_get("service_context").cstate.get(res["state"]) ) - assert set(aresp.keys()) == {"state", "code"} + assert set(aresp.keys()) == { + "state", "code", 'iss', 'client_id', + 'scope', 'nonce', 'response_type', 'redirect_uri'} def test_get_client_authn_method(self): rph_1 = RPHandler( @@ -361,7 +362,7 @@ def test_get_tokens(self): _context = client.client_get("service_context") _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - _nonce = _session["auth_request"]["nonce"] + _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} @@ -406,15 +407,23 @@ def test_get_tokens(self): atresp = ( client.client_get("service", "accesstoken") .client_get("service_context") - .state.get_item(AccessTokenResponse, "token_response", res["state"]) + .cstate.get(res["state"]) ) assert set(atresp.keys()) == { + "__expires_at", + "__verified_id_token", "access_token", + 'client_id', + 'code', "expires_in", "id_token", - "token_type", - "__verified_id_token", - "__expires_at", + 'iss', + 'nonce', + 'redirect_uri', + 'response_type', + 'scope', + 'state', + "token_type" } def test_access_and_id_token(self): @@ -426,7 +435,7 @@ def test_access_and_id_token(self): _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] + _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} @@ -475,7 +484,7 @@ def test_access_and_id_token_by_reference(self): _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] + _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} @@ -524,7 +533,7 @@ def test_get_user_info(self): _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] + _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} @@ -586,7 +595,7 @@ def test_userinfo_in_id_token(self): _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] # _context = client.client_get("service_context") - _nonce = _session["auth_request"]["nonce"] + _nonce = _session["nonce"] _iss = _session["iss"] _aud = client.get_client_id() idval = { diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index 454707ed..d64e906a 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -36,14 +36,14 @@ def create_request(self): entity.client_get("service_context").issuer = "https://server.otherop.com" self.service = entity.client_get("service", "userinfo") - entity.client_get("service_context").work_condition.use = { + entity.client_get("service_context").work_environment.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", } def test_unpack_aggregated_response(self): - _state_interface = self.service.client_get("service_context").state + _cstate = self.service.client_get("service_context").cstate # Add history auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", @@ -56,12 +56,12 @@ def test_unpack_aggregated_response(self): } }, ) - _state = _state_interface.create_state("issuer") + _state = _cstate.create_state(iss="issuer") auth_request["state"] = _state - _state_interface.store_item(auth_request, "auth_request", _state) + _cstate.update(_state, auth_request) - auth_response = AuthorizationResponse(code="access_code").to_json() - _state_interface.store_item(auth_response, "auth_response", "abcde") + auth_response = AuthorizationResponse(code="access_code") + _cstate.update("abcde", auth_response) _distributed_respone = { "iss": "https://server.otherop.com", diff --git a/tests/test_client_55_token_exchange.py b/tests/test_client_55_token_exchange.py index 707dd14d..197c8b7d 100644 --- a/tests/test_client_55_token_exchange.py +++ b/tests/test_client_55_token_exchange.py @@ -70,20 +70,19 @@ def create_request(self): entity.client_get("service_context").issuer = "https://example.com" self.service = entity.client_get("service", "token_exchange") - _state_interface = self.service.client_get("service_context").state + _cstate = self.service.client_get("service_context").cstate # Add history - auth_response = AuthorizationResponse(code="access_code").to_json() - _state_interface.store_item(auth_response, "auth_response", "abcde") + auth_response = AuthorizationResponse(code="access_code") + _cstate.update("abcde", auth_response) idtval = {"nonce": "KUEYfRM2VzKDaaKD", "sub": "diana", "iss": ISS, "aud": "client_id"} idt = create_jws(idtval) ver_idt = IdToken().from_jwt(idt, make_keyjar()) - token_response = AccessTokenResponse( - access_token="access_token", id_token=idt, __verified_id_token=ver_idt - ).to_json() - _state_interface.store_item(token_response, "token_response", "abcde") + token_response = AccessTokenResponse(access_token="access_token", id_token=idt, + __verified_id_token=ver_idt) + _cstate.update("abcde", token_response) def test_construct(self): _req = self.service.construct(state="abcde") diff --git a/tests/test_tandem_10_token_exchange.py b/tests/test_tandem_10_token_exchange.py index b462b6c4..d3e50ea8 100644 --- a/tests/test_tandem_10_token_exchange.py +++ b/tests/test_tandem_10_token_exchange.py @@ -269,8 +269,8 @@ def process_setup(self, token=None, scope=None): _nonce = rndstr(24), _context = self.client_1.get_service_context() # Need a new state for a new authorization request - _state = _context.state.create_state(_context.get("issuer")) - _context.state.store_nonce2state(_nonce, _state) + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) req_args = { "response_type": ["code"], diff --git a/tests/test_client_00_state.py b/tests/xtest_client_00_state.py similarity index 100% rename from tests/test_client_00_state.py rename to tests/xtest_client_00_state.py From e11d01439342fe869a26b30003cb0558ce6de1b6 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 1 Dec 2022 10:30:42 +0100 Subject: [PATCH 23/76] Make server side also use WorkEnvironment --- src/idpyoidc/client/oauth2/access_token.py | 2 +- src/idpyoidc/client/oidc/access_token.py | 2 +- src/idpyoidc/client/oidc/authorization.py | 2 +- src/idpyoidc/client/oidc/userinfo.py | 6 +- src/idpyoidc/client/provider/github.py | 4 +- src/idpyoidc/client/provider/linkedin.py | 4 +- src/idpyoidc/client/service_context.py | 14 +- src/idpyoidc/client/state_interface.py | 383 ------------------ .../client/work_environment/__init__.py | 245 ----------- .../client/work_environment/oauth2.py | 2 +- src/idpyoidc/client/work_environment/oidc.py | 2 +- src/idpyoidc/server/__init__.py | 10 +- src/idpyoidc/server/client_authn.py | 31 +- src/idpyoidc/server/endpoint.py | 12 +- src/idpyoidc/server/endpoint_context.py | 58 ++- src/idpyoidc/server/oidc/authorization.py | 34 +- .../server/oidc/backchannel_authentication.py | 3 +- src/idpyoidc/server/oidc/provider_config.py | 2 +- src/idpyoidc/server/oidc/session.py | 3 +- src/idpyoidc/server/oidc/token.py | 7 +- src/idpyoidc/server/oidc/userinfo.py | 13 +- .../server/work_environment/__init__.py | 8 + .../server/work_environment/oauth2.py | 34 ++ src/idpyoidc/server/work_environment/oidc.py | 67 +++ src/idpyoidc/work_environment.py | 248 ++++++++++++ tests/static/jwks.json | 2 +- tests/test_08_transform.py | 107 +++-- tests/test_09_work_condition.py | 2 - tests/test_12_context.py | 88 ---- tests/test_client_06_client_authn.py | 2 +- tests/test_client_19_webfinger.py | 3 +- tests/test_server_00a_client_configure.py | 4 +- 32 files changed, 543 insertions(+), 861 deletions(-) delete mode 100644 src/idpyoidc/client/state_interface.py create mode 100644 src/idpyoidc/server/work_environment/__init__.py create mode 100644 src/idpyoidc/server/work_environment/oauth2.py create mode 100644 src/idpyoidc/server/work_environment/oidc.py create mode 100644 src/idpyoidc/work_environment.py delete mode 100644 tests/test_12_context.py diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index e8f6a076..df0804d0 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -5,10 +5,10 @@ from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service from idpyoidc.client.work_environment import get_client_authn_methods -from idpyoidc.client.work_environment import get_signing_algs from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.time_util import time_sans_frac +from idpyoidc.work_environment import get_signing_algs LOGGER = logging.getLogger(__name__) diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 342bb4db..3c392615 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -6,7 +6,7 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG from idpyoidc.client.work_environment import get_client_authn_methods -from idpyoidc.client.work_environment import get_signing_algs +from idpyoidc.work_environment import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index d14345e1..e1484fcc 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -3,7 +3,7 @@ from typing import Optional from typing import Union -from idpyoidc.client import work_environment +from idpyoidc import work_environment from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 034a71af..5c54b806 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -5,9 +5,9 @@ from idpyoidc import verified_claim_name from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service -from idpyoidc.client.work_environment import get_encryption_algs -from idpyoidc.client.work_environment import get_encryption_encs -from idpyoidc.client.work_environment import get_signing_algs +from idpyoidc.work_environment import get_encryption_algs +from idpyoidc.work_environment import get_encryption_encs +from idpyoidc.work_environment import get_signing_algs from idpyoidc.exception import MissingSigningKey from idpyoidc.message import Message from idpyoidc.message import oidc diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index ed4ef970..dc50c906 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -1,12 +1,12 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo from idpyoidc.client.work_environment import get_client_authn_methods -from idpyoidc.client.work_environment import get_signing_algs +from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.work_environment import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index a9ad7931..916d6f35 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -1,13 +1,13 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo from idpyoidc.client.work_environment import get_client_authn_methods -from idpyoidc.client.work_environment import get_signing_algs +from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_JSON from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.message import oauth2 +from idpyoidc.work_environment import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 2fb1dbfd..ae8c9573 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -19,11 +19,11 @@ from idpyoidc.client.work_environment.oauth2 import WorkEnvironment as OAUTH2_Specs from idpyoidc.client.work_environment.oidc import WorkEnvironment as OIDC_Specs from idpyoidc.util import rndstr +from idpyoidc.work_environment import WorkEnvironment +from idpyoidc.work_environment import work_environment_dump +from idpyoidc.work_environment import work_environment_load from .configure import get_configuration from .current import Current -from .work_environment import WorkEnvironment -from .work_environment import work_environment_dump -from .work_environment import work_environment_load from .work_environment.transform import preferred_to_registered from .work_environment.transform import supported_to_preferred from ..impexp import ImpExp @@ -167,7 +167,7 @@ def __init__(self, setattr(self, key, val) self.keyjar = self.work_environment.load_conf(config.conf, supports=self.supports(), - keyjar=keyjar) + keyjar=keyjar) _response_types = self.get_preference( 'response_types_supported', @@ -343,9 +343,9 @@ def prefer_or_support(self, claim): def map_supported_to_preferred(self, info: Optional[dict] = None): self.work_environment.prefer = supported_to_preferred(self.supports(), - self.work_environment.prefer, - base_url=self.base_url, - info=info) + self.work_environment.prefer, + base_url=self.base_url, + info=info) return self.work_environment.prefer def map_preferred_to_registered(self, registration_response: Optional[dict] = None): diff --git a/src/idpyoidc/client/state_interface.py b/src/idpyoidc/client/state_interface.py deleted file mode 100644 index 94818d77..00000000 --- a/src/idpyoidc/client/state_interface.py +++ /dev/null @@ -1,383 +0,0 @@ -"""A database interface for storing state information.""" -import json - -from idpyoidc.impexp import ImpExp -from idpyoidc.message import SINGLE_OPTIONAL_JSON -from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message -from idpyoidc.message.oidc import verified_claim_name -from idpyoidc.util import rndstr - - -class State(Message): - """A structure to keep information about previous events.""" - - c_param = { - "iss": SINGLE_REQUIRED_STRING, - "auth_request": SINGLE_OPTIONAL_JSON, - "auth_response": SINGLE_OPTIONAL_JSON, - "token_response": SINGLE_OPTIONAL_JSON, - "refresh_token_request": SINGLE_OPTIONAL_JSON, - "refresh_token_response": SINGLE_OPTIONAL_JSON, - "user_info": SINGLE_OPTIONAL_JSON, - } - - -KEY_PATTERN = { - "nonce": "__{}__", - "logout state": "::{}::", - "session id": "..{}..", - "subject id": "=={}==", -} - - -class InMemoryStateDataBase: - """The simplest possible implementation of the state database.""" - - def __init__(self): - self._db = {} - - def set(self, key, value): - """Assign a value to a key.""" - self._db[key] = value - - def get(self, key): - """Return the value bound to a key.""" - try: - return self._db[key] - except KeyError: - return None - - def delete(self, key): - """Delete a key and its value.""" - try: - del self._db[key] - except KeyError: - pass - - def __setitem__(self, key, value): - """Assign a value to a key.""" - self._db[key] = value - - def __getitem__(self, key): - """Return the value bound to a key.""" - try: - return self._db[key] - except KeyError: - return None - - def __delitem__(self, key): - """Delete a key and its value.""" - try: - del self._db[key] - except KeyError: - pass - - -class StateInterface(ImpExp): - """A more powerful interface to a state DB.""" - - parameter = {"_db": None} - - def __init__(self): - ImpExp.__init__(self) - self._db = {} - - def get_state(self, key): - """ - Get the state connected to a given key. - - :param key: Key into the state database - :return: A :py:class:´idpyoidc.client.current.Current` instance - """ - _data = self._db.get(key) - if not _data: - raise KeyError(key) - - return State().from_json(_data) - - def store_item(self, item, item_type, key): - """ - Store a service response. - - :param item: The item as a :py:class:`idpyoidc.message.Message` - subclass instance or a JSON document. - :param item_type: The type of request or response - :param key: The key under which the information should be stored in - the state database - """ - try: - _state = self.get_state(key) - except KeyError: - _state = State() - - try: - _state[item_type] = item.to_json() - except AttributeError: - _state[item_type] = item - - self._db[key] = _state.to_json() - - def get_iss(self, key): - """ - Get the Issuer ID - - :param key: Key to the information in the state database - :return: The issuer ID - """ - _state = self.get_state(key) - if not _state: - raise KeyError(key) - return _state["iss"] - - def get_item(self, item_cls, item_type, key): - """ - Get a piece of information (a request or a response) from the state - database. - - :param item_cls: The :py:class:`idpyoidc.message.Message` subclass - that described the item. - :param item_type: Which request/response that is wanted - :param key: The key to the information in the state database - :return: A :py:class:`idpyoidc.message.Message` instance - """ - _state = self.get_state(key) - try: - return item_cls(**_state[item_type]) - except TypeError: - return item_cls().from_json(_state[item_type]) - - def extend_request_args(self, args, item_cls, item_type, key, parameters, orig=False): - """ - Add a set of parameters and their value to a set of request arguments. - - :param args: A dictionary - :param item_cls: The :py:class:`idpyoidc.message.Message` subclass - that describes the item - :param item_type: The type of item, this is one of the parameter - names in the :py:class:`idpyoidc.client.current.Current` class. - :param key: The key to the information in the database - :param parameters: A list of parameters who's values this method - will return. - :param orig: Where the value of a claim is a signed JWT return - that. - :return: A dictionary with keys from the list of parameters and - values being the values of those parameters in the item. - If the parameter does not a appear in the item it will not appear - in the returned dictionary. - """ - try: - item = self.get_item(item_cls, item_type, key) - except KeyError: - pass - else: - for parameter in parameters: - if orig: - try: - args[parameter] = item[parameter] - except KeyError: - pass - else: - try: - args[parameter] = item[verified_claim_name(parameter)] - except KeyError: - try: - args[parameter] = item[parameter] - except KeyError: - pass - - return args - - def multiple_extend_request_args(self, args, key, parameters, item_types, orig=False): - """ - Go through a set of items (by their type) and add the attribute-value - that match the list of parameters to the arguments - If the same parameter occurs in 2 different items then the value in - the later one will be the one used. - - :param args: Initial set of arguments - :param key: Key to the State information in the state database - :param parameters: A list of parameters that we're looking for - :param item_types: A list of item_type specifying which items we - are interested in. - :param orig: Where the value of a claim is a signed JWT return - that. - :return: A possibly augmented set of arguments. - """ - _state = self.get_state(key) - - for typ in item_types: - try: - _item = Message(**_state[typ]) - except KeyError: - continue - - for parameter in parameters: - if orig: - try: - args[parameter] = _item[parameter] - except KeyError: - pass - else: - try: - args[parameter] = _item[verified_claim_name(parameter)] - except KeyError: - try: - args[parameter] = _item[parameter] - except KeyError: - pass - - return args - - def store_x2state(self, value, state, xtyp): - """ - Store the connection between some value (x) and a state value. - This allows us later in the game to find the state if we have x. - - :param value: The value - :param state: The state value - :param xtyp: The type of value x is (e.g. nonce, ...) - """ - self._db[KEY_PATTERN[xtyp].format(value)] = state - try: - _val = self._db.get("ref{}ref".format(state)) - except KeyError: - _val = None - - if _val is None: - refs = {xtyp: value} - else: - refs = json.loads(_val) - refs[xtyp] = value - self._db["ref{}ref".format(state)] = json.dumps(refs) - - def get_state_by_x(self, value, xtyp): - """ - Find the state value by providing the x value. - Will raise an exception if the x value is absent from the state - data base. - - :param value: The value - :return: The state value - """ - _state = self._db.get(KEY_PATTERN[xtyp].format(value)) - if _state: - return _state - - raise KeyError('Unknown {}: "{}"'.format(xtyp, value)) - - def store_nonce2state(self, nonce, state): - """ - Store the connection between a nonce value and a state value. - This allows us later in the game to find the state if we have the nonce. - - :param nonce: The nonce value - :param state: The state value - """ - self.store_x2state(nonce, state, "nonce") - - def get_state_by_nonce(self, nonce): - """ - Find the state value by providing the nonce value. - Will raise an exception if the nonce value is absent from the state - data base. - - :param nonce: The nonce value - :return: The state value - """ - return self.get_state_by_x(nonce, "nonce") - - def store_logout_state2state(self, logout_state, state): - """ - Store the connection between a logout state value and a state value. - This allows us later in the game to find the state if we have the - logout state value. - - :param logout_state: The logout state value - :param state: The state value - """ - self.store_x2state(logout_state, state, "logout state") - - def get_state_by_logout_state(self, logout_state): - """ - Find the state value by providing the logout state value. - Will raise an exception if the logout state value is absent from the - state database. - - :param logout_state: The logout state value - :return: The state value - """ - return self.get_state_by_x(logout_state, "logout state") - - def store_sid2state(self, sid, state): - """ - Store the connection between a session id (sid) value and a state value. - This allows us later in the game to find the state if we have the - sid value. - - :param sid: The session ID value - :param state: The state value - """ - self.store_x2state(sid, state, "session id") - - def get_state_by_sid(self, sid): - """ - Find the state value by providing the logout state value. - Will raise an exception if the logout state value is absent from the - state database. - - :param sid: The session ID value - :return: The state value - """ - return self.get_state_by_x(sid, "session id") - - def store_sub2state(self, sub, state): - """ - Store the connection between a subject id (sub) value and a state value. - This allows us later in the game to find the state if we have the - sub value. - - :param sub: The Subject ID value - :param state: The state value - """ - self.store_x2state(sub, state, "subject id") - - def get_state_by_sub(self, sub): - """ - Find the state value by providing the subject id value. - Will raise an exception if the subject id value is absent from the - state database. - - :param sub: The Subject ID value - :return: The state value - """ - return self.get_state_by_x(sub, "subject id") - - def create_state(self, iss, key=""): - """ - Create a State and assign some value to it. - - :param iss: The issuer - :param key: A key to use to access the state - """ - if not key: - key = rndstr(32) - else: - if key.startswith("__") and key.endswith("__"): - raise ValueError('Invalid format. Leading and trailing "__" not allowed') - - _state = State(iss=iss) - self._db[key] = _state.to_json() - return key - - def remove_state(self, state): - """ - Remove a state. - - :param state: Key to the state - """ - del self._db[state] - refs = json.loads(self._db.get("ref{}ref".format(state))) - if refs: - for xtyp, _val in refs.items(): - del self._db[KEY_PATTERN[xtyp].format(_val)] diff --git a/src/idpyoidc/client/work_environment/__init__.py b/src/idpyoidc/client/work_environment/__init__.py index 8c31f144..02d022ae 100644 --- a/src/idpyoidc/client/work_environment/__init__.py +++ b/src/idpyoidc/client/work_environment/__init__.py @@ -1,249 +1,4 @@ -from functools import cmp_to_key -from typing import Callable -from typing import Optional - -from cryptojwt import KeyJar -from cryptojwt.exception import IssuerNotFound -from cryptojwt.jwe import SUPPORTED -from cryptojwt.jwk.hmac import SYMKey -from cryptojwt.jws.jws import SIGNER_ALGS -from cryptojwt.key_jar import init_key_jar -from cryptojwt.utils import importer - from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD -from idpyoidc.client.util import get_uri -from idpyoidc.impexp import ImpExp -from idpyoidc.util import qualified_name - - -def work_environment_dump(info, exclude_attributes): - return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} - - -def work_environment_load(item: dict, **kwargs): - _class_name = list(item.keys())[0] # there is only one - _cls = importer(_class_name) - _cls = _cls().load(item[_class_name]) - return _cls - - -class WorkEnvironment(ImpExp): - parameter = { - "prefer": None, - "use": None, - "callback_path": None, - "_local": None - } - - _supports = {} - - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None): - - ImpExp.__init__(self) - if isinstance(prefer, dict): - self.prefer = {k: v for k, v in prefer.items() if k in self.supports} - else: - self.prefer = {} - - self.callback_path = callback_path or {} - self.use = {} - self._local = {} - - def get_use(self): - return self.use - - def set_usage(self, key, value): - self.use[key] = value - - def get_usage(self, key, default=None): - return self.use.get(key, default) - - def get_preference(self, key, default=None): - return self.prefer.get(key, default) - - def set_preference(self, key, value): - self.prefer[key] = value - - def remove_preference(self, key): - if key in self.prefer: - del self.prefer[key] - - def _callback_uris(self, base_url, hex): - _uri = [] - for type in self.get_usage("response_types", self._supports['response_types']): - if "code" in type: - _uri.append('code') - elif type in ["id_token", "id_token token"]: - _uri.append('implicit') - - if "form_post" in self.supports: - _uri.append("form_post") - - callback_uri = {} - for key in _uri: - callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) - return callback_uri - - def construct_redirect_uris(self, - base_url: str, - hex: str, - callbacks: Optional[dict] = None): - if not callbacks: - callbacks = self._callback_uris(base_url, hex) - - if callbacks: - self.set_preference('callbacks', callbacks) - self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) - - self.callback = callbacks - - def verify_rules(self): - return True - - def locals(self, info): - pass - - def _keyjar(self, keyjar=None, conf=None, entity_id=""): - _uri_path = '' - if keyjar is None: - if "keys" in conf: - keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - _uri_path = conf['keys'].get('uri_path') - elif "key_conf" in conf and conf["key_conf"]: - keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - _uri_path = conf['key_conf'].get('uri_path') - else: - _keyjar = KeyJar() - if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") - - if "" in _keyjar and entity_id: - # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) - - _httpc_params = conf.get("httpc_params") - if _httpc_params: - _keyjar.httpc_params = _httpc_params - - return _keyjar, _uri_path - else: - return keyjar, _uri_path - - def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): - _jwks = _jwks_uri = None - _id = self.get_preference('client_id') - keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) - - _secret = self.get_preference('client_secret') - if _secret: - keyjar.add_symmetric(issuer_id=_id, key=_secret) - keyjar.add_symmetric(issuer_id='', key=_secret) - - # now that keys are in the Key Jar, now for how to publish it - if 'jwks_uri' in configuration: # simple - _jwks_uri = configuration.get('jwks_uri') - elif uri_path: - _jwks_uri = f"{configuration.get('base_url')}{uri_path}" - else: # jwks or nothing - # if only the client secret, no need to publish as a JWKS - try: - _own_keys = keyjar.get_issuer_keys('') - except IssuerNotFound: - pass - else: - if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): - pass - else: - _jwks = keyjar.export_jwks() - - return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} - - def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None): - for attr, val in configuration.items(): - if attr == "preference": - for k, v in val.items(): - if k in supports: - self.set_preference(k, v) - elif attr in supports: - self.set_preference(attr, val) - - self.locals(configuration) - - for key, val in self.handle_keys(configuration, keyjar=keyjar).items(): - if key == 'keyjar': - keyjar = val - elif val: - self.set_preference(key, val) - - self.verify_rules() - return keyjar - - def get(self, key, default=None): - if key in self._local: - return self._local[key] - else: - return default - - def set(self, key, val): - self._local[key] = val - - def construct_uris(self, *args): - pass - - def supports(self): - res = {} - for key, val in self._supports.items(): - if isinstance(val, Callable): - res[key] = val() - else: - res[key] = val - return res - - def supported(self, claim): - return claim in self._supports - - def prefers(self): - return self.prefer - - -SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] - - -def cmp(a, b): - return (a > b) - (a < b) - - -def alg_cmp(a, b): - if a == 'none': - return 1 - elif b == 'none': - return -1 - - _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) - _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) - if _pos1 == _pos2: - return (a > b) - (a < b) - elif _pos1 > _pos2: - return 1 - else: - return -1 - - -def get_signing_algs(): - # Assumes Cryptojwt - return sorted(list(SIGNER_ALGS.keys()), key=cmp_to_key(alg_cmp)) - - -def get_encryption_algs(): - return SUPPORTED['alg'] - - -def get_encryption_encs(): - return SUPPORTED['enc'] def get_client_authn_methods(): diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 40293cf2..4b1bb605 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,6 +1,6 @@ from typing import Optional -from idpyoidc.client import work_environment +from idpyoidc import work_environment class WorkEnvironment(work_environment.WorkEnvironment): diff --git a/src/idpyoidc/client/work_environment/oidc.py b/src/idpyoidc/client/work_environment/oidc.py index e5adbcb1..3843523f 100644 --- a/src/idpyoidc/client/work_environment/oidc.py +++ b/src/idpyoidc/client/work_environment/oidc.py @@ -1,7 +1,7 @@ import os from typing import Optional -from idpyoidc.client import work_environment +from idpyoidc import work_environment class WorkEnvironment(work_environment.WorkEnvironment): diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 1e52fe47..ccf09ac5 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -45,6 +45,10 @@ def __init__( ): ImpExp.__init__(self) self.conf = conf + + self.endpoint = do_endpoints(conf, self.server_get) + + # endpoint context MUST be done after do_endpoints !! self.endpoint_context = EndpointContext( conf=conf, server_get=self.server_get, @@ -54,13 +58,11 @@ def __init__( httpc=httpc, ) self.endpoint_context.authz = self.setup_authz() + # _cap = get_provider_capabilities(conf, self.endpoint) + # self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.setup_authentication(self.endpoint_context) - self.endpoint = do_endpoints(conf, self.server_get) - _cap = get_provider_capabilities(conf, self.endpoint) - - self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.endpoint_context.do_add_on(endpoints=self.endpoint) self.endpoint_context.session_manager = create_session_manager( diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 65caac50..c11e1a8e 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -5,6 +5,12 @@ from typing import Optional from typing import Union +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from idpyoidc.server.endpoint_context import EndpointContext + + from cryptojwt.exception import BadSignature from cryptojwt.exception import Invalid from cryptojwt.exception import MissingKey @@ -17,7 +23,6 @@ from idpyoidc.message.oidc import JsonWebToken from idpyoidc.message.oidc import verified_claim_name from idpyoidc.server.constant import JWT_BEARER -from idpyoidc.server.endpoint_context import EndpointContext from idpyoidc.server.exception import BearerTokenAuthenticationError from idpyoidc.server.exception import ClientAuthenticationError from idpyoidc.server.exception import InvalidClient @@ -44,7 +49,7 @@ def __init__(self, server_get): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -122,7 +127,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -144,7 +149,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -169,7 +174,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -202,7 +207,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -226,7 +231,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -257,7 +262,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -285,7 +290,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -354,7 +359,7 @@ class ClientSecretJWT(JWSAuthnMethod): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -376,7 +381,7 @@ class PrivateKeyJWT(JWSAuthnMethod): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -403,7 +408,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -454,7 +459,7 @@ def valid_client_info(cinfo): def verify_client( - endpoint_context: EndpointContext, + endpoint_context: "EndpointContext", request: Union[dict, Message], http_info: Optional[dict] = None, get_client_id_from_token: Optional[Callable] = None, diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 2167285f..4e05f52e 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -89,9 +89,10 @@ class Endpoint(object): response_placement = "body" client_authn_method = "" default_capabilities = None - provider_info_attributes = None auth_method_attribute = "" + _supports = {} + def __init__(self, server_get: Callable, **kwargs): self.server_get = server_get self.pre_construct = [] @@ -447,3 +448,12 @@ def allowed_target_uris(self): else: res.append(self.server_get("endpoint", t).full_path) return set(res) + + def supports(self): + res = {} + for key, val in self._supports.items(): + if isinstance(val, Callable): + res[key] = val() + else: + res[key] = val + return res diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 2316baae..7755ad9b 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -11,12 +11,15 @@ import requests from idpyoidc.context import OidcContext +from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.scopes import SCOPE2CLAIMS from idpyoidc.server.scopes import Scopes from idpyoidc.server.session.manager import SessionManager from idpyoidc.server.template_handler import Jinja2TemplateHandler from idpyoidc.server.util import get_http_params +from idpyoidc.server.work_environment.oauth2 import WorkEnvironment as OAUTH2_Env +from idpyoidc.server.work_environment.oidc import WorkEnvironment as OIDC_Env from idpyoidc.util import importer from idpyoidc.util import rndstr @@ -119,11 +122,19 @@ def __init__( cwd: Optional[str] = "", cookie_handler: Optional[Any] = None, httpc: Optional[Any] = None, + server_type: Optional[str] = '' ): - OidcContext.__init__(self, conf, keyjar, entity_id=conf.get("issuer", "")) + OidcContext.__init__(self, conf, entity_id=conf.get("issuer", "")) self.conf = conf self.server_get = server_get + if not server_type or server_type == "oidc": + self.work_environment = OIDC_Env() + elif server_type == "oauth2": + self.work_environment = OAUTH2_Env() + else: + raise ValueError(f"Unknown server type: {server_type}") + _client_db = conf.get("client_db") if _client_db: logger.debug(f"Loading client db using: {_client_db}") @@ -151,7 +162,7 @@ def __init__( self.httpc = httpc or requests self.idtoken = None self.issuer = "" - self.jwks_uri = None + # self.jwks_uri = None self.login_hint_lookup = None self.login_hint2acrs = None self.par_db = {} @@ -198,15 +209,15 @@ def __init__( if _loader: self.template_handler = Jinja2TemplateHandler(_loader) - # self.setup = {} - _keys_conf = conf.get("key_conf") - if _keys_conf: - jwks_uri_path = _keys_conf["uri_path"] - - if self.issuer.endswith("/"): - self.jwks_uri = "{}{}".format(self.issuer, jwks_uri_path) - else: - self.jwks_uri = "{}/{}".format(self.issuer, jwks_uri_path) + # # self.setup = {} + # _keys_conf = conf.get("key_conf") + # if _keys_conf: + # jwks_uri_path = _keys_conf["uri_path"] + # + # if self.issuer.endswith("/"): + # self.jwks_uri = "{}{}".format(self.issuer, jwks_uri_path) + # else: + # self.jwks_uri = "{}/{}".format(self.issuer, jwks_uri_path) for item in [ "cookie_handler", @@ -236,6 +247,9 @@ def __init__( self.dev_auth_db = None self.claims_interface = init_service(conf["claims_interface"], self.server_get) + self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), + keyjar=keyjar) + def new_cookie(self, name: str, max_age: Optional[int] = 0, **kwargs): cookie_cont = self.cookie_handler.make_cookie_content( name=name, value=json.dumps(kwargs), max_age=max_age @@ -365,3 +379,25 @@ def do_login_hint_lookup(self): self.login_hint_lookup = init_service(_conf) self.login_hint_lookup.userinfo = _userinfo + + def supports(self): + res = {} + if self.server_get: + for endpoint in self.server_get('endpoints').values(): + res.update(endpoint.supports()) + res.update(self.work_environment.supports()) + return res + + def set_provider_info(self): + prefers = self.work_environment.prefer + supported = self.supports() + _info = {} + for key, spec in ProviderConfigurationResponse.c_param.items(): + _val = prefers.get(key, None) + if _val is None: + _val = supported.get(key, None) + if _val is None: + continue + _info[key] = _val + + self.provider_info = _info diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index ef77ace9..ae302bb9 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -2,6 +2,7 @@ from typing import Callable from urllib.parse import urlsplit +from idpyoidc import work_environment from idpyoidc.message import oidc from idpyoidc.message.oidc import Claims from idpyoidc.message.oidc import verified_claim_name @@ -74,32 +75,19 @@ class Authorization(authorization.Authorization): response_placement = "url" endpoint_name = "authorization_endpoint" name = "authorization" - provider_info_attributes = { + _supports = { "claims_parameter_supported": True, - "client_authn_method": ["request_param", "public"], + "encrypt_request_object_supported": None, + "request_object_signing_alg_values_supported": work_environment.get_signing_algs, + "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, + "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, "request_parameter_supported": True, "request_uri_parameter_supported": True, - "response_types_supported": [ - "code", - "token", - "id_token", - "code token", - "code id_token", - "id_token token", - "code id_token token", - ], - "response_modes_supported": ["query", "fragment", "form_post"], - "request_object_signing_alg_values_supported": None, - "request_object_encryption_alg_values_supported": None, - "request_object_encryption_enc_values_supported": None, - "grant_types_supported": ["authorization_code", "implicit"], - "claim_types_supported": ["normal", "aggregated", "distributed"], - } - metadata_claims = { - - } - default_capabilities = { - "client_authn_method": ["request_param", "public"], + "require_request_uri_registration": False, + "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', + 'code id_token', 'code idtoken token'], + "response_modes_supported": ['query', 'fragment', 'form_post'], + "subject_types_supported": ["public", "pairwise", "ephemeral"], } def __init__(self, server_get: Callable, **kwargs): diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index eca81a6f..aaf44ce1 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -36,7 +36,8 @@ class BackChannelAuthentication(Endpoint): response_placement = "url" endpoint_name = "backchannel_authentication_endpoint" name = "backchannel_authentication" - provider_info_attributes = { + + _supports = { "backchannel_token_delivery_modes_supported": ["poll", "ping", "push"], "backchannel_authentication_request_signing_alg_values_supported": None, "backchannel_user_code_parameter_supported": True, diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 507fbab9..38f61584 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -12,7 +12,7 @@ class ProviderConfiguration(Endpoint): request_format = "" response_format = "json" name = "provider_config" - provider_info_attributes = {"require_request_uri_registration": None} + # _supports = {"require_request_uri_registration": None} def __init__(self, server_get, **kwargs): Endpoint.__init__(self, server_get=server_get, **kwargs) diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 716b8277..841419a8 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -80,7 +80,8 @@ class Session(Endpoint): response_placement = "url" endpoint_name = "end_session_endpoint" name = "session" - provider_info_attributes = { + + _supports = { "frontchannel_logout_supported": True, "frontchannel_logout_session_required": True, "backchannel_logout_supported": True, diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index 523c6b75..19491baa 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -23,7 +23,8 @@ class Token(token.Token): endpoint_name = "token_endpoint" name = "token" default_capabilities = None - provider_info_attributes = { + + _supports = { "token_endpoint_auth_methods_supported": [ "client_secret_post", "client_secret_basic", @@ -32,7 +33,9 @@ class Token(token.Token): ], "token_endpoint_auth_signing_alg_values_supported": None, } - auth_method_attribute = "token_endpoint_auth_methods_supported" + + # auth_method_attribute = "token_endpoint_auth_methods_supported" + helper_by_grant_type = { "authorization_code": AccessTokenHelper, "refresh_token": RefreshTokenHelper, diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index ae6e87b5..c965e3e4 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -9,6 +9,7 @@ from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac +from idpyoidc import work_environment from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -27,14 +28,12 @@ class UserInfo(Endpoint): response_placement = "body" endpoint_name = "userinfo_endpoint" name = "userinfo" - provider_info_attributes = { + _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], - "userinfo_signing_alg_values_supported": None, - "userinfo_encryption_alg_values_supported": None, - "userinfo_encryption_enc_values_supported": None, - } - default_capabilities = { - "client_authn_method": ["bearer_header", "bearer_body"], + "encrypt_userinfo_supported": False, + "userinfo_signing_alg_values_supported": work_environment.get_signing_algs, + "userinfo_encryption_alg_values_supported": work_environment.get_encryption_algs, + "userinfo_encryption_enc_values_supported": work_environment.get_encryption_encs, } def __init__(self, server_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): diff --git a/src/idpyoidc/server/work_environment/__init__.py b/src/idpyoidc/server/work_environment/__init__.py new file mode 100644 index 00000000..249d369b --- /dev/null +++ b/src/idpyoidc/server/work_environment/__init__.py @@ -0,0 +1,8 @@ +from idpyoidc.message.oidc import ProviderConfigurationResponse +from idpyoidc.server.client_authn import CLIENT_AUTHN_METHOD + + +def get_client_authn_methods(): + return list(CLIENT_AUTHN_METHOD.keys()) + + diff --git a/src/idpyoidc/server/work_environment/oauth2.py b/src/idpyoidc/server/work_environment/oauth2.py new file mode 100644 index 00000000..4b1bb605 --- /dev/null +++ b/src/idpyoidc/server/work_environment/oauth2.py @@ -0,0 +1,34 @@ +from typing import Optional + +from idpyoidc import work_environment + + +class WorkEnvironment(work_environment.WorkEnvironment): + _supports = { + "redirect_uris": None, + "grant_types": ["authorization_code", "implicit", "refresh_token"], + "response_types": ["code"], + "client_id": None, + 'client_secret': None, + "client_name": None, + "client_uri": None, + "logo_uri": None, + "contacts": None, + "scopes_supported": [], + "tos_uri": None, + "policy_uri": None, + "jwks_uri": None, + "jwks": None, + "software_id": None, + "software_version": None + } + + + callback_path = {} + + callback_uris = ["redirect_uris"] + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): + work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) diff --git a/src/idpyoidc/server/work_environment/oidc.py b/src/idpyoidc/server/work_environment/oidc.py new file mode 100644 index 00000000..1107affb --- /dev/null +++ b/src/idpyoidc/server/work_environment/oidc.py @@ -0,0 +1,67 @@ +import os +from typing import Optional + +from idpyoidc import work_environment + + +class WorkEnvironment(work_environment.WorkEnvironment): + parameter = work_environment.WorkEnvironment.parameter.copy() + + _supports = { + "acr_values_supported": None, + "claim_types_supported": None, + "claims_locales_supported": None, + "claims_supported": None, + "contacts": None, + "default_max_age": 86400, + "display_values_supported": None, + "encrypt_id_token_supported": None, + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "id_token_signing_alg_values_supported": work_environment.get_signing_algs, + "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, + "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, + "initiate_login_uri": None, + "jwks": None, + "jwks_uri": None, + "op_policy_uri": None, + "require_auth_time": None, + "scopes_supported": ["openid"], + "service_documentation": None, + "op_tos_uri": None, + "ui_locales_supported": None + # "verify_args": None, + } + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None + ): + work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) + + def verify_rules(self): + if self.get_preference("request_parameter_supported") and self.get_preference( + "request_uri_parameter_supported"): + raise ValueError( + "You have to chose one of 'request_parameter_supported' and " + "'request_uri_parameter_supported'. You can't have both.") + + if not self.get_preference('encrypt_userinfo_supported'): + self.set_preference('userinfo_encryption_alg_values_supported', []) + self.set_preference('userinfo_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_request_object_supported'): + self.set_preference('request_object_encryption_alg_values_supported', []) + self.set_preference('request_object_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_id_token_supported'): + self.set_preference('id_token_encryption_alg_values_supported', []) + self.set_preference('id_token_encryption_enc_values_supported', []) + + def locals(self, info): + requests_dir = info.get("requests_dir") + if requests_dir: + # make sure the path exists. If not, then create it. + if not os.path.isdir(requests_dir): + os.makedirs(requests_dir) + + self.set("requests_dir", requests_dir) diff --git a/src/idpyoidc/work_environment.py b/src/idpyoidc/work_environment.py new file mode 100644 index 00000000..79e8aee6 --- /dev/null +++ b/src/idpyoidc/work_environment.py @@ -0,0 +1,248 @@ +from functools import cmp_to_key +from typing import Callable +from typing import Optional + +from cryptojwt import KeyJar +from cryptojwt.exception import IssuerNotFound +from cryptojwt.jwe import SUPPORTED +from cryptojwt.jwk.hmac import SYMKey +from cryptojwt.jws.jws import SIGNER_ALGS +from cryptojwt.key_jar import init_key_jar +from cryptojwt.utils import importer + +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD +from idpyoidc.client.util import get_uri +from idpyoidc.impexp import ImpExp +from idpyoidc.util import qualified_name + + +def work_environment_dump(info, exclude_attributes): + return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} + + +def work_environment_load(item: dict, **kwargs): + _class_name = list(item.keys())[0] # there is only one + _cls = importer(_class_name) + _cls = _cls().load(item[_class_name]) + return _cls + + +class WorkEnvironment(ImpExp): + parameter = { + "prefer": None, + "use": None, + "callback_path": None, + "_local": None + } + + _supports = {} + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): + + ImpExp.__init__(self) + if isinstance(prefer, dict): + self.prefer = {k: v for k, v in prefer.items() if k in self.supports} + else: + self.prefer = {} + + self.callback_path = callback_path or {} + self.use = {} + self._local = {} + + def get_use(self): + return self.use + + def set_usage(self, key, value): + self.use[key] = value + + def get_usage(self, key, default=None): + return self.use.get(key, default) + + def get_preference(self, key, default=None): + return self.prefer.get(key, default) + + def set_preference(self, key, value): + self.prefer[key] = value + + def remove_preference(self, key): + if key in self.prefer: + del self.prefer[key] + + def _callback_uris(self, base_url, hex): + _uri = [] + for type in self.get_usage("response_types", self._supports['response_types']): + if "code" in type: + _uri.append('code') + elif type in ["id_token", "id_token token"]: + _uri.append('implicit') + + if "form_post" in self.supports: + _uri.append("form_post") + + callback_uri = {} + for key in _uri: + callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) + return callback_uri + + def construct_redirect_uris(self, + base_url: str, + hex: str, + callbacks: Optional[dict] = None): + if not callbacks: + callbacks = self._callback_uris(base_url, hex) + + if callbacks: + self.set_preference('callbacks', callbacks) + self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) + + self.callback = callbacks + + def verify_rules(self): + return True + + def locals(self, info): + pass + + def _keyjar(self, keyjar=None, conf=None, entity_id=""): + _uri_path = '' + if keyjar is None: + if "keys" in conf: + keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['keys'].get('uri_path') + elif "key_conf" in conf and conf["key_conf"]: + keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['key_conf'].get('uri_path') + else: + _keyjar = KeyJar() + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") + + if "" in _keyjar and entity_id: + # make sure I have the keys under my own name too (if I know it) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + + _httpc_params = conf.get("httpc_params") + if _httpc_params: + _keyjar.httpc_params = _httpc_params + + return _keyjar, _uri_path + else: + return keyjar, _uri_path + + def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): + _jwks = _jwks_uri = None + _id = self.get_preference('client_id') + keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) + + _secret = self.get_preference('client_secret') + if _secret: + keyjar.add_symmetric(issuer_id=_id, key=_secret) + keyjar.add_symmetric(issuer_id='', key=_secret) + + # now that keys are in the Key Jar, now for how to publish it + if 'jwks_uri' in configuration: # simple + _jwks_uri = configuration.get('jwks_uri') + elif uri_path: + _jwks_uri = f"{configuration.get('base_url')}{uri_path}" + else: # jwks or nothing + # if only the client secret, no need to publish as a JWKS + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + pass + else: + if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): + pass + else: + _jwks = keyjar.export_jwks() + + return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} + + def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None): + for attr, val in configuration.items(): + if attr == "preference": + for k, v in val.items(): + if k in supports: + self.set_preference(k, v) + elif attr in supports: + self.set_preference(attr, val) + + self.locals(configuration) + + for key, val in self.handle_keys(configuration, keyjar=keyjar).items(): + if key == 'keyjar': + keyjar = val + elif val: + self.set_preference(key, val) + + self.verify_rules() + return keyjar + + def get(self, key, default=None): + if key in self._local: + return self._local[key] + else: + return default + + def set(self, key, val): + self._local[key] = val + + def construct_uris(self, *args): + pass + + def supports(self): + res = {} + for key, val in self._supports.items(): + if isinstance(val, Callable): + res[key] = val() + else: + res[key] = val + return res + + def supported(self, claim): + return claim in self._supports + + def prefers(self): + return self.prefer + + +SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] + + +def cmp(a, b): + return (a > b) - (a < b) + + +def alg_cmp(a, b): + if a == 'none': + return 1 + elif b == 'none': + return -1 + + _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) + _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) + if _pos1 == _pos2: + return (a > b) - (a < b) + elif _pos1 > _pos2: + return 1 + else: + return -1 + + +def get_signing_algs(): + # Assumes Cryptojwt + return sorted(list(SIGNER_ALGS.keys()), key=cmp_to_key(alg_cmp)) + + +def get_encryption_algs(): + return SUPPORTED['alg'] + + +def get_encryption_encs(): + return SUPPORTED['enc'] + + diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index ac92245c..606e87b6 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -41,6 +41,7 @@ def setup(self): self.supported = supported def test_supported(self): + assert 'token_endpoint_auth_methods_supported' not in self.supported # These are all the available configuration parameters assert set(self.supported.keys()) == { 'acr_values_supported', @@ -86,7 +87,8 @@ def test_supported(self): 'scopes_supported', 'sector_identifier_uri', 'subject_types_supported', - 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_method', + # 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', 'tos_uri', 'userinfo_encryption_alg_values_supported', @@ -96,58 +98,56 @@ def test_supported(self): def test_oidc_setup(self): # This is OP specified stuff assert set(ProviderConfigurationResponse.c_param.keys()).difference( - set(self.supported)) == { - 'authorization_endpoint', - 'check_session_iframe', - 'claim_types_supported', - 'claims_locales_supported', - 'claims_parameter_supported', - 'claims_supported', - 'display_values_supported', - 'end_session_endpoint', - 'error', - 'error_description', - 'error_uri', - 'issuer', - 'op_policy_uri', - 'op_tos_uri', - 'registration_endpoint', - # 'request_parameter_supported', - # 'request_uri_parameter_supported', - 'require_request_uri_registration', - 'service_documentation', - 'token_endpoint', - 'ui_locales_supported', - 'userinfo_endpoint'} + set(self.supported)) == {'authorization_endpoint', + 'check_session_iframe', + 'claim_types_supported', + 'claims_locales_supported', + 'claims_parameter_supported', + 'claims_supported', + 'display_values_supported', + 'end_session_endpoint', + 'error', + 'error_description', + 'error_uri', + 'issuer', + 'op_policy_uri', + 'op_tos_uri', + 'registration_endpoint', + 'require_request_uri_registration', + 'service_documentation', + 'token_endpoint', + 'token_endpoint_auth_methods_supported', + 'ui_locales_supported', + 'userinfo_endpoint'} # parameters that are not mapped against what the OP's provider info says assert set(self.supported).difference( - set(ProviderConfigurationResponse.c_param.keys())) == { - 'application_type', - 'backchannel_logout_uri', - 'callback_uris', - 'client_id', - 'client_name', - 'client_secret', - 'client_uri', - 'contacts', - 'default_max_age', - 'encrypt_id_token_supported', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'frontchannel_logout_uri', - 'initiate_login_uri', - 'jwks', - 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_parameter', - 'request_uris', - 'requests_dir', - 'require_auth_time', - 'sector_identifier_uri', - 'tos_uri'} + set(ProviderConfigurationResponse.c_param.keys())) == {'application_type', + 'backchannel_logout_uri', + 'callback_uris', + 'client_id', + 'client_name', + 'client_secret', + 'client_uri', + 'contacts', + 'default_max_age', + 'encrypt_id_token_supported', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'frontchannel_logout_uri', + 'initiate_login_uri', + 'jwks', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uris', + 'redirect_uris', + 'request_parameter', + 'request_uris', + 'requests_dir', + 'require_auth_time', + 'sector_identifier_uri', + 'token_endpoint_auth_method', + 'tos_uri'} preference = {} pref = supported_to_preferred(supported=self.supported, preference=preference, @@ -168,7 +168,7 @@ def test_oidc_setup(self): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported', @@ -182,7 +182,7 @@ def test_oidc_setup(self): reg_claim.append(key) assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == { - 'post_logout_redirect_uri'} + 'post_logout_redirect_uri', 'token_endpoint_auth_method'} # Which ones are list -> singletons @@ -242,7 +242,7 @@ def test_provider_info(self): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported', @@ -340,7 +340,6 @@ def test_registration_response(self): 'request_object_signing_alg', 'response_types', 'subject_type', - 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 82df52f6..8353d34a 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -175,7 +175,6 @@ def test_registration_response(self): 'request_object_signing_alg', 'response_types', 'subject_type', - 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} @@ -214,7 +213,6 @@ def test_registration_response(self): 'id_token_signed_response_alg', 'jwks', 'jwks_uri', - 'keyjar', 'logo_uri', 'redirect_uris', 'request_object_signing_alg', diff --git a/tests/test_12_context.py b/tests/test_12_context.py deleted file mode 100644 index 0f5919d2..00000000 --- a/tests/test_12_context.py +++ /dev/null @@ -1,88 +0,0 @@ -import copy -import shutil - -import pytest - -from idpyoidc.context import OidcContext - -KEYDEF = [ - {"type": "EC", "crv": "P-256", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["enc"]}, -] - -JWKS = { - "keys": [ - { - "n": "zkpUgEgXICI54blf6iWiD2RbMDCOO1jV0VSff1MFFnujM4othfMsad7H1kRo50YM5S" - "_X9TdvrpdOfpz5aBaKFhT6Ziv0nhtcekq1eRl8mjBlvGKCE5XGk-0LFSDwvqgkJoFY" - "Inq7bu0a4JEzKs5AyJY75YlGh879k1Uu2Sv3ZZOunfV1O1Orta-NvS-aG_jN5cstVb" - "CGWE20H0vFVrJKNx0Zf-u-aA-syM4uX7wdWgQ-owoEMHge0GmGgzso2lwOYf_4znan" - "LwEuO3p5aabEaFoKNR4K6GjQcjBcYmDEE4CtfRU9AEmhcD1kleiTB9TjPWkgDmT9MX" - "sGxBHf3AKT5w", - "e": "AQAB", - "kty": "RSA", - "kid": "rsa1", - }, - { - "k": "YTEyZjBlMDgxMGI4YWU4Y2JjZDFiYTFlZTBjYzljNDU3YWM0ZWNiNzhmNmFlYTNkNTY0NzMzYjE", - "kty": "oct", - }, - ] -} - - -def test_dump_load(): - c = OidcContext({}) - assert c.keyjar is not None - mem = c.dump() - c2 = OidcContext().load(mem) - assert c2.keyjar is not None - - -class TestDumpLoad(object): - @pytest.fixture(autouse=True) - def setup(self): - self.conf = {"issuer": "https://example.com"} - - def test_context_with_entity_id_no_keys(self): - c = OidcContext(self.conf, entity_id="https://example.com") - mem = c.dump() - c2 = OidcContext().load(mem) - assert c2.keyjar.owners() == [] - - def test_context_with_entity_id_and_keys(self): - conf = copy.deepcopy(self.conf) - conf["keys"] = {"key_defs": KEYDEF} - c = OidcContext(conf, entity_id="https://example.com") - - mem = c.dump() - c2 = OidcContext().load(mem) - assert set(c2.keyjar.owners()) == {"", "https://example.com"} - - def test_context_with_entity_id_and_jwks(self): - conf = copy.deepcopy(self.conf) - conf["jwks"] = JWKS - c = OidcContext(conf, entity_id="https://example.com") - - mem = c.dump() - c2 = OidcContext().load(mem) - - assert set(c2.keyjar.owners()) == {"", "https://example.com"} - assert len(c2.keyjar.get("sig", "RSA")) == 1 - assert len(c2.keyjar.get("sig", "RSA", issuer_id="https://example.com")) == 1 - assert len(c2.keyjar.get("sig", "oct")) == 1 - assert len(c2.keyjar.get("sig", "oct", issuer_id="https://example.com")) == 1 - - def test_context_restore(self): - conf = copy.deepcopy(self.conf) - conf["keys"] = {"key_defs": KEYDEF} - - c = OidcContext(conf, entity_id="https://example.com") - mem = c.dump() - c2 = OidcContext().load(mem) - - assert set(c2.keyjar.owners()) == {"", "https://example.com"} - assert len(c2.keyjar.get("sig", "EC")) == 1 - assert len(c2.keyjar.get("enc", "EC")) == 1 - assert len(c.keyjar.get("sig", "RSA")) == 0 - assert len(c.keyjar.get("sig", "oct")) == 0 diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 1544d1ee..060ef48b 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -21,7 +21,7 @@ from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity -from idpyoidc.client.work_environment import WorkEnvironment +from idpyoidc.work_environment import WorkEnvironment from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenRequest diff --git a/tests/test_client_19_webfinger.py b/tests/test_client_19_webfinger.py index a1a289c4..0edc919b 100644 --- a/tests/test_client_19_webfinger.py +++ b/tests/test_client_19_webfinger.py @@ -8,14 +8,13 @@ from idpyoidc.client.entity import Entity from idpyoidc.client.oidc import OIC_ISSUER from idpyoidc.client.oidc.webfinger import WebFinger -from idpyoidc.client.service_context import ServiceContext from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.message.oidc import JRD from idpyoidc.message.oidc import Link __author__ = "Roland Hedberg" -ENTITY = Entity(config={"base_url":"https://example.com"}) +ENTITY = Entity(config={"base_url": "https://example.com"}) def test_query(): diff --git a/tests/test_server_00a_client_configure.py b/tests/test_server_00a_client_configure.py index a6f18bbf..665d936f 100644 --- a/tests/test_server_00a_client_configure.py +++ b/tests/test_server_00a_client_configure.py @@ -9,7 +9,7 @@ BASEDIR = os.path.abspath(os.path.dirname(__file__)) -extra = { +EXTRA = { "token_usage_rules": { "authorization_code": { "expires_in": 600, @@ -106,7 +106,7 @@ def test_verify_oidc_client_information_complext(): } } - client_conf["client1"].update(extra) + client_conf["client1"].update(EXTRA) res = verify_oidc_client_information(client_conf, server_get=server.server_get) assert res From afdb39ddc3667b2d412ca0f225bf4ebf0762153d Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 1 Dec 2022 19:10:43 +0100 Subject: [PATCH 24/76] Partly done with harmonizing work environment usage. --- .../client/work_environment/__init__.py | 37 ++++ .../client/work_environment/oauth2.py | 1 - src/idpyoidc/server/__init__.py | 1 + src/idpyoidc/server/client_authn.py | 163 +++++++++--------- src/idpyoidc/server/endpoint.py | 6 - src/idpyoidc/server/endpoint_context.py | 26 ++- src/idpyoidc/server/oidc/authorization.py | 2 +- src/idpyoidc/server/oidc/provider_config.py | 16 +- src/idpyoidc/server/oidc/registration.py | 39 +++-- src/idpyoidc/server/oidc/token.py | 4 +- src/idpyoidc/server/util.py | 12 +- .../server/work_environment/__init__.py | 14 +- .../server/work_environment/oauth2.py | 30 ++-- src/idpyoidc/server/work_environment/oidc.py | 22 +-- src/idpyoidc/work_environment.py | 34 ++-- tests/static/jwks.json | 2 +- tests/test_server_16_endpoint_context.py | 162 +++++++---------- tests/test_server_20a_server.py | 9 +- ...server_22_oidc_provider_config_endpoint.py | 44 +---- 19 files changed, 304 insertions(+), 320 deletions(-) diff --git a/src/idpyoidc/client/work_environment/__init__.py b/src/idpyoidc/client/work_environment/__init__.py index 02d022ae..7eba9d42 100644 --- a/src/idpyoidc/client/work_environment/__init__.py +++ b/src/idpyoidc/client/work_environment/__init__.py @@ -1,5 +1,42 @@ +from cryptojwt.exception import IssuerNotFound +from cryptojwt.jwk.hmac import SYMKey + +from idpyoidc import work_environment from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD def get_client_authn_methods(): return list(CLIENT_AUTHN_METHOD.keys()) + + +class WorkEnvironment(work_environment.WorkEnvironment): + + def get_base_url(self, configuration: dict): + _base = configuration.get('base_url') + if not _base: + _base = configuration.get('client_id') + + return _base + + def get_id(self, configuration: dict): + return self.get_preference('client_id') + + def add_extra_keys(self, keyjar, id): + _secret = self.get_preference('client_secret') + if _secret: + keyjar.add_symmetric(issuer_id=id, key=_secret) + keyjar.add_symmetric(issuer_id='', key=_secret) + + def get_jwks(self, keyjar): + _jwks = None + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + pass + else: + if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): + pass + else: + _jwks = keyjar.export_jwks() + + return _jwks diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 4b1bb605..69b5a20c 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -23,7 +23,6 @@ class WorkEnvironment(work_environment.WorkEnvironment): "software_version": None } - callback_path = {} callback_uris = ["redirect_uris"] diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index ccf09ac5..df93fd21 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -57,6 +57,7 @@ def __init__( cookie_handler=cookie_handler, httpc=httpc, ) + self.endpoint_context.set_provider_info() self.endpoint_context.authz = self.setup_authz() # _cap = get_provider_capabilities(conf, self.endpoint) # self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index c11e1a8e..98233f5d 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -3,14 +3,12 @@ from typing import Callable from typing import Dict from typing import Optional -from typing import Union - from typing import TYPE_CHECKING +from typing import Union if TYPE_CHECKING: from idpyoidc.server.endpoint_context import EndpointContext - from cryptojwt.exception import BadSignature from cryptojwt.exception import Invalid from cryptojwt.exception import MissingKey @@ -48,12 +46,12 @@ def __init__(self, server_get): self.server_get = server_get def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): """ Verify authentication information in a request @@ -63,12 +61,12 @@ def _verify( raise NotImplementedError() def verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): """ Verify authentication information in a request @@ -87,9 +85,9 @@ def verify( return res def is_usable( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, ): """ Verify that this authentication method is applicable. @@ -126,12 +124,12 @@ def is_usable(self, request=None, authorization_token=None): return request is not None def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request.get("client_id")} @@ -148,12 +146,12 @@ def is_usable(self, request=None, authorization_token=None): return request and "client_id" in request def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request["client_id"]} @@ -173,12 +171,12 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): client_info = basic_authn(authorization_token) @@ -206,12 +204,12 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): if endpoint_context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: return {"client_id": request["client_id"]} @@ -230,13 +228,13 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): token = authorization_token.split(" ", 1)[1] try: @@ -281,6 +279,7 @@ def _verify( class JWSAuthnMethod(ClientAuthnMethod): + def is_usable(self, request=None, authorization_token=None): if request is None: return False @@ -289,13 +288,13 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - key_type: Optional[str] = None, - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + key_type: Optional[str] = None, + **kwargs, ): _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: @@ -358,12 +357,12 @@ class ClientSecretJWT(JWSAuthnMethod): tag = "client_secret_jwt" def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( endpoint_context, request=request, key_type="client_secret", endpoint=endpoint, **kwargs @@ -380,12 +379,12 @@ class PrivateKeyJWT(JWSAuthnMethod): tag = "private_key_jwt" def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( endpoint_context, @@ -407,12 +406,12 @@ def is_usable(self, request=None, authorization_token=None): return True def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + endpoint_context: "EndpointContext", + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) try: @@ -459,12 +458,12 @@ def valid_client_info(cinfo): def verify_client( - endpoint_context: "EndpointContext", - request: Union[dict, Message], - http_info: Optional[dict] = None, - get_client_id_from_token: Optional[Callable] = None, - endpoint=None, # Optional[Endpoint] - also_known_as: Optional[Dict[str, str]] = None, + endpoint_context: "EndpointContext", + request: Union[dict, Message], + http_info: Optional[dict] = None, + get_client_id_from_token: Optional[Callable] = None, + endpoint=None, # Optional[Endpoint] + also_known_as: Optional[Dict[str, str]] = None, ): """ Initiated Guessing ! @@ -568,3 +567,7 @@ def client_auth_setup(server_get, auth_set=None): cls = importer(cls) res[name] = cls(server_get) return res + + +def get_client_authn_methods(): + return list(CLIENT_AUTHN_METHOD.keys()) diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 4e05f52e..083940bb 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -138,12 +138,6 @@ def set_client_authn_methods(self, **kwargs): self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) return kwargs - def get_provider_info_attributes(self): - _pia = construct_provider_info(self.provider_info_attributes, **self.kwargs) - if self.endpoint_name: - _pia[self.endpoint_name] = self.full_path - return _pia - def process_verify_error(self, exception): _error = "invalid_request" return self.error_cls(error=_error, error_description="%s" % exception) diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 7755ad9b..e6fb11d1 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -247,8 +247,12 @@ def __init__( self.dev_auth_db = None self.claims_interface = init_service(conf["claims_interface"], self.server_get) - self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), - keyjar=keyjar) + if isinstance(conf, OPConfiguration): + self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), + keyjar=keyjar) + else: # OidcConfig + self.keyjar = self.work_environment.load_conf(conf, supports=self.supports(), + keyjar=keyjar) def new_cookie(self, name: str, max_age: Optional[int] = 0, **kwargs): cookie_cont = self.cookie_handler.make_cookie_content( @@ -391,13 +395,25 @@ def supports(self): def set_provider_info(self): prefers = self.work_environment.prefer supported = self.supports() - _info = {} + _info = {'issuer': self.issuer} for key, spec in ProviderConfigurationResponse.c_param.items(): _val = prefers.get(key, None) - if _val is None: + if not _val and _val != False: _val = supported.get(key, None) - if _val is None: + if not _val and _val != False: continue _info[key] = _val self.provider_info = _info + + def get_preference(self, claim, default=None): + return self.work_environment.get_preference(claim, default=default) + + def set_preference(self, key, value): + self.work_environment.set_preference(key, value) + + def get_usage(self, claim, default: Optional[str] = None): + return self.work_environment.get_usage(claim, default) + + def set_usage(self, claim, value): + return self.work_environment.set_usage(claim, value) diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index ae302bb9..6a7d5eef 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -85,7 +85,7 @@ class Authorization(authorization.Authorization): "request_uri_parameter_supported": True, "require_request_uri_registration": False, "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', - 'code id_token', 'code idtoken token'], + 'code id_token', 'code id_token token'], "response_modes_supported": ['query', 'fragment', 'form_post'], "subject_types_supported": ["public", "pairwise", "ephemeral"], } diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 38f61584..0c0de15a 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -18,19 +18,19 @@ def __init__(self, server_get, **kwargs): Endpoint.__init__(self, server_get=server_get, **kwargs) self.pre_construct.append(self.add_endpoints) - def add_endpoints(self, request, client_id, endpoint_context, **kwargs): + def add_endpoints(self, info, client_id, endpoint_context, **kwargs): for endpoint in [ - "authorization_endpoint", - "registration_endpoint", - "token_endpoint", - "userinfo_endpoint", - "end_session_endpoint", + "authorization", + "provider_config", + "token", + "userinfo", + "session", ]: endp_instance = self.server_get("endpoint", endpoint) if endp_instance: - request[endpoint] = endp_instance.endpoint_path + info[endp_instance.endpoint_name] = endp_instance.full_path - return request + return info def process_request(self, request=None, **kwargs): return {"response_args": self.server_get("endpoint_context").provider_info} diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 8076f618..283b5cc5 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -10,6 +10,8 @@ from cryptojwt.utils import as_bytes # from idpyoidc.defaults import PREFERENCE2SUPPORTED +from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED + from idpyoidc.exception import MessageException from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import ClientRegistrationErrorResponse @@ -136,20 +138,22 @@ def __init__(self, *args, **kwargs): _seed = kwargs.get("seed") or rndstr(32) self.seed = as_bytes(_seed) - def match_client_request(self, request): - _context = self.server_get("endpoint_context") - # for _pref, _prov in PREFERENCE2SUPPORTED.items(): - # if _pref in request: - # if _pref in ["response_types", "default_acr_values"]: - # if not match_sp_sep(request[_pref], _context.provider_info[_prov]): - # raise CapabilitiesMisMatch(_pref) - # else: - # if isinstance(request[_pref], str): - # if request[_pref] not in _context.provider_info[_prov]: - # raise CapabilitiesMisMatch(_pref) - # else: - # if not set(request[_pref]).issubset(set(_context.provider_info[_prov])): - # raise CapabilitiesMisMatch(_pref) + def match_client_request(self, request: dict) -> list: + err = [] + _provider_info = self.server_get("endpoint_context").provider_info + for key, val in request.items(): + if key not in REGISTER2PREFERRED: + continue + _pi_key = REGISTER2PREFERRED.get(key, key) + if isinstance(val, str): + if val not in _provider_info[_pi_key]: + logger.error(f"CapabilitiesMisMatch: {key}") + err.append(key) + else: + if not set(val).issubset(set(_provider_info[_pi_key])): + logger.error(f"CapabilitiesMisMatch: {key}") + err.append(key) + return err def do_client_registration(self, request, client_id, ignore=None): if ignore is None: @@ -377,12 +381,11 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): return ResponseMessage(error=_error, error_description="%s" % err) request.rm_blanks() - try: - self.match_client_request(request) - except CapabilitiesMisMatch as err: + faulty_claims = self.match_client_request(request) + if faulty_claims: return ResponseMessage( error="invalid_request", - error_description="Don't support proposed %s" % err, + error_description=f"Don't support proposed {faulty_claims}" ) _context = self.server_get("endpoint_context") diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index 19491baa..cab8f66c 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -1,5 +1,7 @@ import logging +from idpyoidc import work_environment + from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import TokenErrorResponse @@ -31,7 +33,7 @@ class Token(token.Token): "client_secret_jwt", "private_key_jwt", ], - "token_endpoint_auth_signing_alg_values_supported": None, + "token_endpoint_auth_signing_alg_values_supported": work_environment.get_signing_algs, } # auth_method_attribute = "token_endpoint_auth_methods_supported" diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index 241bcebf..3e00d43f 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -51,12 +51,6 @@ def build_endpoints(conf, server_get, issuer): _instance.endpoint_path = _path _instance.full_path = "{}/{}".format(_url, _path) - # if _instance.endpoint_name: - # try: - # _instance.endpoint_info[_instance.endpoint_name] = _instance.full_path - # except TypeError: - # _instance.endpoint_info = {_instance.endpoint_name: _instance.full_path} - endpoint[_instance.name] = _instance return endpoint @@ -135,9 +129,9 @@ def allow_refresh_token(endpoint_context): # Is refresh_token grant type supported _token_supported = False - _cap = endpoint_context.conf.get("capabilities") - if _cap: - if "refresh_token" in _cap["grant_types_supported"]: + _supported = endpoint_context.get_preference("grant_types_supported") + if _supported: + if "refresh_token" in _supported: # self.allow_refresh = kwargs.get("allow_refresh", True) _token_supported = True diff --git a/src/idpyoidc/server/work_environment/__init__.py b/src/idpyoidc/server/work_environment/__init__.py index 249d369b..06bbda43 100644 --- a/src/idpyoidc/server/work_environment/__init__.py +++ b/src/idpyoidc/server/work_environment/__init__.py @@ -1,8 +1,14 @@ -from idpyoidc.message.oidc import ProviderConfigurationResponse -from idpyoidc.server.client_authn import CLIENT_AUTHN_METHOD +from idpyoidc import work_environment -def get_client_authn_methods(): - return list(CLIENT_AUTHN_METHOD.keys()) +class WorkEnvironment(work_environment.WorkEnvironment): + def get_base_url(self, configuration: dict): + _base = configuration.get('base_url') + if not _base: + _base = configuration.get('issuer') + return _base + + def get_id(self, configuration: dict): + return configuration.get('issuer') diff --git a/src/idpyoidc/server/work_environment/oauth2.py b/src/idpyoidc/server/work_environment/oauth2.py index 4b1bb605..42ab7579 100644 --- a/src/idpyoidc/server/work_environment/oauth2.py +++ b/src/idpyoidc/server/work_environment/oauth2.py @@ -1,26 +1,26 @@ from typing import Optional -from idpyoidc import work_environment +from idpyoidc.server import work_environment class WorkEnvironment(work_environment.WorkEnvironment): + # 'issuer', 'authorization_endpoint', 'token_endpoint', 'jwks_uri', 'registration_endpoint', + # 'scopes_supported', 'response_types_supported', 'response_modes_supported', + # 'grant_types_supported', 'token_endpoint_auth_methods_supported', + # 'token_endpoint_auth_signing_alg_values_supported', 'service_documentation', + # 'ui_locales_supported', 'op_policy_uri', 'op_tos_uri', 'revocation_endpoint', + # 'introspection_endpoint' _supports = { - "redirect_uris": None, - "grant_types": ["authorization_code", "implicit", "refresh_token"], - "response_types": ["code"], - "client_id": None, - 'client_secret': None, - "client_name": None, - "client_uri": None, - "logo_uri": None, - "contacts": None, - "scopes_supported": [], - "tos_uri": None, - "policy_uri": None, + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "response_types_supported": ["code"], + "response_modes_supported": ["code"], "jwks_uri": None, "jwks": None, - "software_id": None, - "software_version": None + "scopes_supported": [], + "service_documentation": None, + "ui_locales_supported": [], + "op_tos_uri": None, + "op_policy_uri": None, } diff --git a/src/idpyoidc/server/work_environment/oidc.py b/src/idpyoidc/server/work_environment/oidc.py index 1107affb..c776da38 100644 --- a/src/idpyoidc/server/work_environment/oidc.py +++ b/src/idpyoidc/server/work_environment/oidc.py @@ -1,7 +1,7 @@ -import os from typing import Optional -from idpyoidc import work_environment +from idpyoidc import work_environment as WE +from idpyoidc.server import work_environment class WorkEnvironment(work_environment.WorkEnvironment): @@ -17,9 +17,9 @@ class WorkEnvironment(work_environment.WorkEnvironment): "display_values_supported": None, "encrypt_id_token_supported": None, "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "id_token_signing_alg_values_supported": work_environment.get_signing_algs, - "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, - "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, + "id_token_signing_alg_values_supported": WE.get_signing_algs, + "id_token_encryption_alg_values_supported": WE.get_encryption_algs, + "id_token_encryption_enc_values_supported": WE.get_encryption_encs, "initiate_login_uri": None, "jwks": None, "jwks_uri": None, @@ -28,7 +28,8 @@ class WorkEnvironment(work_environment.WorkEnvironment): "scopes_supported": ["openid"], "service_documentation": None, "op_tos_uri": None, - "ui_locales_supported": None + "ui_locales_supported": None, + # "version": '3.0' # "verify_args": None, } @@ -56,12 +57,3 @@ def verify_rules(self): if not self.get_preference('encrypt_id_token_supported'): self.set_preference('id_token_encryption_alg_values_supported', []) self.set_preference('id_token_encryption_enc_values_supported', []) - - def locals(self, info): - requests_dir = info.get("requests_dir") - if requests_dir: - # make sure the path exists. If not, then create it. - if not os.path.isdir(requests_dir): - os.makedirs(requests_dir) - - self.set("requests_dir", requests_dir) diff --git a/src/idpyoidc/work_environment.py b/src/idpyoidc/work_environment.py index 79e8aee6..5b1f56fc 100644 --- a/src/idpyoidc/work_environment.py +++ b/src/idpyoidc/work_environment.py @@ -13,6 +13,7 @@ from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.util import get_uri from idpyoidc.impexp import ImpExp +from idpyoidc.util import add_path from idpyoidc.util import qualified_name @@ -133,32 +134,33 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): else: return keyjar, _uri_path + def get_base_url(self, configuration: dict): + raise NotImplemented() + + def get_id(self, configuration: dict): + raise NotImplemented() + + def add_extra_keys(self, keyjar, id): + return None + + def get_jwks(self, keyjar): + return None + def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): _jwks = _jwks_uri = None - _id = self.get_preference('client_id') + _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) - _secret = self.get_preference('client_secret') - if _secret: - keyjar.add_symmetric(issuer_id=_id, key=_secret) - keyjar.add_symmetric(issuer_id='', key=_secret) + self.add_extra_keys(keyjar, _id) # now that keys are in the Key Jar, now for how to publish it if 'jwks_uri' in configuration: # simple _jwks_uri = configuration.get('jwks_uri') elif uri_path: - _jwks_uri = f"{configuration.get('base_url')}{uri_path}" + _base_url = self.get_base_url(configuration) + _jwks_uri = add_path(_base_url, uri_path) else: # jwks or nothing - # if only the client secret, no need to publish as a JWKS - try: - _own_keys = keyjar.get_issuer_keys('') - except IssuerNotFound: - pass - else: - if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): - pass - else: - _jwks = keyjar.export_jwks() + _jwks = self.get_jwks(keyjar) return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index 4a18da16..1406c3be 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -4,20 +4,16 @@ import pytest from cryptojwt.key_jar import build_keyjar +from idpyoidc import work_environment from idpyoidc.server import OPConfiguration from idpyoidc.server import Server -from idpyoidc.server import do_endpoints from idpyoidc.server.endpoint import Endpoint -from idpyoidc.server.endpoint_context import EndpointContext -from idpyoidc.server.endpoint_context import get_provider_capabilities from idpyoidc.server.exception import OidcEndpointError -from idpyoidc.server.session.manager import create_session_manager from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.util import allow_refresh_token - from . import CRYPT_CONFIG -from . import SESSION_PARAMS from . import full_path +from . import SESSION_PARAMS KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -29,9 +25,9 @@ class Endpoint_1(Endpoint): name = "userinfo" - default_capabilities = { + _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], - "userinfo_signing_alg_values_supported": None, + "userinfo_signing_alg_values_supported": work_environment.get_signing_algs, "userinfo_encryption_alg_values_supported": None, "userinfo_encryption_enc_values_supported": None, "client_authn_method": ["bearer_header", "bearer_body"], @@ -42,27 +38,24 @@ class Endpoint_1(Endpoint): "issuer": "https://example.com/", "template_dir": "template", "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS, "read_only": True}, - "capabilities": { - "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], - }, + "client_authn_method": [ + "private_key_jwt", + "client_secret_jwt", + "client_secret_post", + "client_secret_basic", + ], + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], "endpoint": { "userinfo": { "path": "userinfo", "class": Endpoint_1, - "kwargs": { - "client_authn_method": [ - "private_key_jwt", - "client_secret_jwt", - "client_secret_post", - "client_secret_basic", - ] - }, + "kwargs": {} } }, "token_handler_args": { @@ -99,64 +92,48 @@ class Endpoint_1(Endpoint): class TestEndpointContext: + @pytest.fixture(autouse=True) def create_endpoint_context(self): - self.endpoint_context = EndpointContext( - conf=conf, - server_get=self.server_get, - keyjar=KEYJAR, - ) - - def server_get(self, *args): - if args[0] == "endpoint_context": - return self.endpoint_context + server = Server(conf) + self.endpoint_context = server.endpoint_context def test(self): - endpoint = do_endpoints(conf, self.server_get) - _cap = get_provider_capabilities(conf, endpoint) - pi = self.endpoint_context.create_providerinfo(_cap) - assert set(pi.keys()) == { - "claims_supported", - "issuer", - "version", - "scopes_supported", - "subject_types_supported", - "grant_types_supported", - } + assert set(self.endpoint_context.provider_info.keys()) == { + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'issuer', + 'jwks_uri', + 'scopes_supported', + 'userinfo_signing_alg_values_supported'} def test_allow_refresh_token(self): - self.endpoint_context.session_manager = create_session_manager( - self.server_get, - self.endpoint_context.th_args, - sub_func=self.endpoint_context._sub_func, - conf=conf, - ) - assert allow_refresh_token(self.endpoint_context) # Have the software but is not expected to use it. - self.endpoint_context.conf["capabilities"]["grant_types_supported"] = [ + self.endpoint_context.set_preference("grant_types_supported", [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", - ] + ]) assert allow_refresh_token(self.endpoint_context) is False # Don't have the software but are expected to use it. - self.endpoint_context.conf["capabilities"]["grant_types_supported"] = [ + self.endpoint_context.set_preference("grant_types_supported", [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", - ] + ]) del self.endpoint_context.session_manager.token_handler.handler["refresh_token"] with pytest.raises(OidcEndpointError): assert allow_refresh_token(self.endpoint_context) is False class Tokenish(Endpoint): - default_capabilities = None - provider_info_attributes = { + _supports = { "token_endpoint_auth_methods_supported": [ "client_secret_post", "client_secret_basic", @@ -165,26 +142,24 @@ class Tokenish(Endpoint): ], "token_endpoint_auth_signing_alg_values_supported": None, } - auth_method_attribute = "token_endpoint_auth_methods_supported" BASEDIR = os.path.abspath(os.path.dirname(__file__)) +# Note no endpoints !! CONF = { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, "token_expires_in": 600, "grant_expires_in": 300, "refresh_token_expires_in": 86400, - "capabilities": { - "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], - }, + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], "keys": { "public_path": "jwks.json", "key_defs": KEYDEFS, @@ -212,38 +187,35 @@ class Tokenish(Endpoint): ) def test_provider_configuration(kwargs): conf = copy.deepcopy(CONF) + conf.update(kwargs) conf["endpoint"] = { - "endpoint": {"path": "endpoint", "class": Tokenish, "kwargs": kwargs}, + "endpoint": {"path": "endpoint", "class": Tokenish, "kwargs": {}}, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) server.endpoint_context.cdb["client_id"] = {} - _endpoints = do_endpoints(conf, server.server_get) - - _cap = get_provider_capabilities(conf, _endpoints) - pi = server.endpoint_context.create_providerinfo(_cap) - assert set(pi.keys()) == { - "version", - "acr_values_supported", - "issuer", - "jwks_uri", - "scopes_supported", - "grant_types_supported", - "claims_supported", - "subject_types_supported", - "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg_values_supported", - } + pi = server.endpoint_context.provider_info + assert set(pi.keys()) == {'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'issuer', + 'jwks_uri', + 'scopes_supported', + 'token_endpoint_auth_methods_supported'} if kwargs: - assert pi["token_endpoint_auth_methods_supported"] == [ - "client_secret_jwt", - "private_key_jwt", - ] + if 'token_endpoint_auth_methods_supported' in kwargs: + assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_jwt', + 'private_key_jwt'] + else: + assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_post', + 'client_secret_basic', + 'client_secret_jwt', + 'private_key_jwt'] + else: - assert pi["token_endpoint_auth_methods_supported"] == [ - "client_secret_post", - "client_secret_basic", - "client_secret_jwt", - "private_key_jwt", - ] + assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_post', + 'client_secret_basic', + 'client_secret_jwt', + 'private_key_jwt'] diff --git a/tests/test_server_20a_server.py b/tests/test_server_20a_server.py index 6d0f78ca..fbcf6008 100755 --- a/tests/test_server_20a_server.py +++ b/tests/test_server_20a_server.py @@ -127,19 +127,20 @@ def test_capabilities_default(): "code id_token token", } assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is True - assert server.endpoint_context.jwks_uri == "https://127.0.0.1:443/static/jwks.json" + assert server.endpoint_context.get_preference('jwks_uri') == \ + "https://127.0.0.1:443/static/jwks.json" def test_capabilities_subset1(): _cnf = deepcopy(CONF) - _cnf["capabilities"] = {"response_types_supported": ["code"]} + _cnf["response_types_supported"] = ["code"] server = Server(_cnf) assert server.endpoint_context.provider_info["response_types_supported"] == ["code"] def test_capabilities_subset2(): _cnf = deepcopy(CONF) - _cnf["capabilities"] = {"response_types_supported": ["code", "id_token"]} + _cnf["response_types_supported"] = ["code", "id_token"] server = Server(_cnf) assert set(server.endpoint_context.provider_info["response_types_supported"]) == { "code", @@ -149,7 +150,7 @@ def test_capabilities_subset2(): def test_capabilities_bool(): _cnf = deepcopy(CONF) - _cnf["capabilities"] = {"request_uri_parameter_supported": False} + _cnf["request_uri_parameter_supported"] = False server = Server(_cnf) assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is False diff --git a/tests/test_server_22_oidc_provider_config_endpoint.py b/tests/test_server_22_oidc_provider_config_endpoint.py index 08c862b5..3cbe05e2 100755 --- a/tests/test_server_22_oidc_provider_config_endpoint.py +++ b/tests/test_server_22_oidc_provider_config_endpoint.py @@ -51,7 +51,7 @@ } -class TestEndpoint(object): +class TestProviderConfigEndpoint(object): @pytest.fixture def conf(self): return { @@ -91,33 +91,12 @@ def test_do_response(self): assert _msg assert _msg["token_endpoint"] == "https://example.com/token" assert _msg["jwks_uri"] == "https://example.com/static/jwks.json" - assert set(_msg["claims_supported"]) == { - "gender", - "zoneinfo", - "website", - "phone_number_verified", - "middle_name", - "family_name", - "nickname", - "email", - "preferred_username", - "profile", - "name", - "phone_number", - "given_name", - "email_verified", - "sub", - "locale", - "picture", - "address", - "updated_at", - "birthdate", - } + assert "claims_supported" not in _msg # No default for this assert ("Content-type", "application/json; charset=utf-8") in msg["http_headers"] def test_scopes_supported(self, conf): scopes_supported = ["openid", "random", "profile"] - conf["capabilities"]["scopes_supported"] = scopes_supported + conf["scopes_supported"] = scopes_supported server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) endpoint = server.server_get("endpoint", "provider_config") @@ -126,20 +105,3 @@ def test_scopes_supported(self, conf): assert isinstance(msg, dict) _msg = json.loads(msg["response"]) assert set(_msg["scopes_supported"]) == set(scopes_supported) - assert set(_msg["claims_supported"]) == { - "zoneinfo", - "gender", - "sub", - "middle_name", - "given_name", - "nickname", - "preferred_username", - "name", - "updated_at", - "birthdate", - "locale", - "profile", - "family_name", - "picture", - "website", - } From 17691c783889614f527a5c679033800228638fab Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 3 Dec 2022 10:40:14 +0100 Subject: [PATCH 25/76] All tests green. --- src/idpyoidc/client/oauth2/authorization.py | 6 +- .../client/work_environment/oauth2.py | 2 +- src/idpyoidc/client/work_environment/oidc.py | 4 +- src/idpyoidc/message/oauth2/__init__.py | 62 ++++++++++++- src/idpyoidc/message/oidc/__init__.py | 2 - src/idpyoidc/server/client_authn.py | 2 +- src/idpyoidc/server/endpoint.py | 2 +- src/idpyoidc/server/endpoint_context.py | 44 +++------ src/idpyoidc/server/oauth2/authorization.py | 11 ++- src/idpyoidc/server/oauth2/token.py | 3 + src/idpyoidc/work_environment.py | 4 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/static/jwks.json | 2 +- tests/test_05_oauth2.py | 2 +- ...er_24_oauth2_authorization_endpoint_jar.py | 22 +++-- ...t_server_24_oidc_authorization_endpoint.py | 19 +++- .../test_server_26_oidc_userinfo_endpoint.py | 91 ++++++++++++------- tests/test_server_50_persistence.py | 30 +----- ...> test_tandem_10_oauth2_token_exchange.py} | 35 +++---- 20 files changed, 202 insertions(+), 145 deletions(-) rename tests/{test_tandem_10_token_exchange.py => test_tandem_10_oauth2_token_exchange.py} (97%) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 221a68d4..b557339e 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -3,6 +3,7 @@ from typing import List from typing import Optional +from idpyoidc import work_environment from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oauth2.utils import set_state_parameter @@ -30,7 +31,10 @@ class Authorization(Service): _supports = { "response_types_supported": ["code", 'token'], - "response_modes_supported": ['query', 'fragment'] + "response_modes_supported": ['query', 'fragment'], + "request_object_signing_alg_values_supported": work_environment.get_signing_algs, + "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, + "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, } _callback_path = { diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 69b5a20c..71fedde9 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,6 +1,6 @@ from typing import Optional -from idpyoidc import work_environment +from idpyoidc.client import work_environment class WorkEnvironment(work_environment.WorkEnvironment): diff --git a/src/idpyoidc/client/work_environment/oidc.py b/src/idpyoidc/client/work_environment/oidc.py index 3843523f..f7996148 100644 --- a/src/idpyoidc/client/work_environment/oidc.py +++ b/src/idpyoidc/client/work_environment/oidc.py @@ -2,9 +2,9 @@ from typing import Optional from idpyoidc import work_environment +from idpyoidc.client import work_environment as client_work_environment - -class WorkEnvironment(work_environment.WorkEnvironment): +class WorkEnvironment(client_work_environment.WorkEnvironment): parameter = work_environment.WorkEnvironment.parameter.copy() parameter.update({ "requests_dir": None diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index ea2e5702..b6666f0e 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -6,6 +6,7 @@ from idpyoidc import verified_claim_name from idpyoidc.exception import MissingAttribute from idpyoidc.exception import VerificationError +from idpyoidc.message import Message from idpyoidc.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from idpyoidc.message import OPTIONAL_LIST_OF_STRINGS from idpyoidc.message import REQUIRED_LIST_OF_SP_SEP_STRINGS @@ -16,7 +17,6 @@ from idpyoidc.message import SINGLE_REQUIRED_BOOLEAN from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message logger = logging.getLogger(__name__) @@ -108,6 +108,19 @@ class AccessTokenRequest(Message): c_default = {"grant_type": "authorization_code"} +CLAIMS_WITH_VERIFIED = ["request"] + + +def clear_verified_claims(msg): + for claim in CLAIMS_WITH_VERIFIED: + _vc_name = verified_claim_name(claim) + try: + del msg[_vc_name] + except KeyError: + pass + return msg + + class AuthorizationRequest(Message): """ An authorization request @@ -119,6 +132,7 @@ class AuthorizationRequest(Message): "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, "redirect_uri": SINGLE_OPTIONAL_STRING, "state": SINGLE_OPTIONAL_STRING, + "request": SINGLE_OPTIONAL_STRING, } def merge(self, request_object, treatement="strict", whitelist=None): @@ -149,6 +163,52 @@ def merge(self, request_object, treatement="strict", whitelist=None): self.update(request_object) + def verify(self, **kwargs): + """Authorization Request parameters that are OPTIONAL in the OAuth 2.0 + specification MAY be included in the OpenID Request Object without also + passing them as OAuth 2.0 Authorization Request parameters, with one + exception: The scope parameter MUST always be present in OAuth 2.0 + Authorization Request parameters. + All parameter values that are present both in the OAuth 2.0 + Authorization Request and in the OpenID Request Object MUST exactly + match.""" + super(AuthorizationRequest, self).verify(**kwargs) + + clear_verified_claims(self) + + args = {} + for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", "encenc"]: + try: + args[arg] = kwargs[arg] + except KeyError: + pass + + if "opponent_id" not in kwargs: + args["opponent_id"] = self["client_id"] + + if "request" in self: + if isinstance(self["request"], str): + # Try to decode the JWT, checks the signature + oidr = AuthorizationRequest().from_jwt(str(self["request"]), **args) + + # check if something is change in the original message + for key, val in oidr.items(): + if key in self: + if self[key] != val: + # log but otherwise ignore + logger.warning("{} != {}".format(self[key], val)) + + # remove all claims + _keys = list(self.keys()) + for key in _keys: + if key not in oidr: + del self[key] + + self.update(oidr) + + # replace the JWT with the parsed and verified instance + self[verified_claim_name("request")] = oidr + class AuthorizationResponse(ResponseMessage): """ diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index cf53198e..9f8fd295 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -451,8 +451,6 @@ def verify(self, **kwargs): match.""" super(AuthorizationRequest, self).verify(**kwargs) - clear_verified_claims(self) - args = {} for arg in ["keyjar", "opponent_id", "sender", "alg", "encalg", "encenc"]: try: diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 98233f5d..c36168a4 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -490,7 +490,7 @@ def verify_client( client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: - allowed_methods = list(methods.keys()) + allowed_methods = list(methods.keys()) # If not specific for this endpoint then all _method = None for _method in (methods[meth] for meth in allowed_methods): diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 083940bb..b84db51d 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -178,7 +178,7 @@ def parse_request( verify=_context.httpc_params["verify"], **kwargs ) - elif self.request_format == "url": + elif self.request_format == "url": # A whole URL not just the query part parts = urlparse(request) scheme, netloc, path, params, query, fragment = parts[:6] req = _cls_inst.deserialize(query, "urlencoded") diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index e6fb11d1..4ccf2f0c 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -323,36 +323,6 @@ def do_sub_func(self) -> None: else: self._sub_func[key] = args["function"] - def create_providerinfo(self, capabilities): - """ - Dynamically create the provider info response - - :param capabilities: - :return: - """ - - _provider_info = capabilities - _provider_info["issuer"] = self.issuer - _provider_info["version"] = "3.0" - - # acr_values - if self.authn_broker: - acr_values = self.authn_broker.get_acr_values() - if acr_values is not None: - _provider_info["acr_values_supported"] = acr_values - - if self.jwks_uri and self.keyjar: - _provider_info["jwks_uri"] = self.jwks_uri - - if "scopes_supported" not in _provider_info: - _provider_info["scopes_supported"] = self.scopes_handler.get_allowed_scopes() - if "claims_supported" not in _provider_info: - _provider_info["claims_supported"] = list( - self.scopes_handler.scopes_to_claims(_provider_info["scopes_supported"]).keys() - ) - - return _provider_info - def set_remember_token(self): ses_par = self.conf.get("session_params") or {} @@ -395,7 +365,12 @@ def supports(self): def set_provider_info(self): prefers = self.work_environment.prefer supported = self.supports() - _info = {'issuer': self.issuer} + _info = {'issuer': self.issuer, 'version': "3.0"} + + for endp in self.server_get('endpoints').values(): + if endp.endpoint_name: + _info[endp.endpoint_name] = endp.full_path + for key, spec in ProviderConfigurationResponse.c_param.items(): _val = prefers.get(key, None) if not _val and _val != False: @@ -404,6 +379,13 @@ def set_provider_info(self): continue _info[key] = _val + # acr_values + if 'acr_values_supported' not in _info: + if self.authn_broker: + acr_values = self.authn_broker.get_acr_values() + if acr_values is not None: + _info["acr_values_supported"] = acr_values + self.provider_info = _info def get_preference(self, claim, default=None): diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 3b2cf50c..d7c6ddb6 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -14,6 +14,7 @@ from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e +from idpyoidc import work_environment from idpyoidc.exception import ImproperlyConfigured from idpyoidc.exception import ParameterError from idpyoidc.exception import URIError @@ -42,6 +43,7 @@ from idpyoidc.util import split_uri from idpyoidc.util import importer + logger = logging.getLogger(__name__) # For the time being. This is JAR specific and should probably be configurable. @@ -335,15 +337,16 @@ class Authorization(Endpoint): response_placement = "url" endpoint_name = "authorization_endpoint" name = "authorization" - metadata_claims = { + + _supports = { "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, "response_types_supported": ["code", "token", "code token"], "response_modes_supported": ["query", "fragment", "form_post"], - "request_object_signing_alg_values_supported": None, - "request_object_encryption_alg_values_supported": None, - "request_object_encryption_enc_values_supported": None, + "request_object_signing_alg_values_supported": work_environment.get_signing_algs, + "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, + "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, "grant_types_supported": ["authorization_code", "implicit"], "scopes_supported": [], } diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 431ae7ad..e0a77196 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -7,13 +7,16 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.message.oidc import TokenErrorResponse +from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError from idpyoidc.server.oauth2.token_helper import AccessTokenHelper from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper from idpyoidc.server.session import MintingNotAllowed +from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING from idpyoidc.util import importer logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/work_environment.py b/src/idpyoidc/work_environment.py index 5b1f56fc..4e5eabad 100644 --- a/src/idpyoidc/work_environment.py +++ b/src/idpyoidc/work_environment.py @@ -135,10 +135,10 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): return keyjar, _uri_path def get_base_url(self, configuration: dict): - raise NotImplemented() + raise NotImplementedError() def get_id(self, configuration: dict): - raise NotImplemented() + raise NotImplementedError() def add_extra_keys(self, keyjar, id): return None diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 77081f40..9b062907 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_05_oauth2.py b/tests/test_05_oauth2.py index 8ca0d3bf..c0cb90cb 100644 --- a/tests/test_05_oauth2.py +++ b/tests/test_05_oauth2.py @@ -268,7 +268,7 @@ def test_verify(self): "&response_type=code&client_id=0123456789" ) ar = AuthorizationRequest().deserialize(query, "urlencoded") - assert ar.verify() + ar.verify() def test_load_dict(self): bib = { diff --git a/tests/test_server_24_oauth2_authorization_endpoint_jar.py b/tests/test_server_24_oauth2_authorization_endpoint_jar.py index 6a719758..5d1f6e30 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_server_24_oauth2_authorization_endpoint_jar.py @@ -139,20 +139,24 @@ def create_endpoint(self): "issuer": "https://example.com/", "password": "mycket hemligt zebra", "verify_ssl": False, - "capabilities": CAPABILITIES, + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "request_cls": JWTSecuredAuthorizationRequest, "endpoint": { "authorization": { "path": "{}/authorization", "class": Authorization, - "kwargs": { - "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], - "response_modes_supported": ["query", "fragment", "form_post"], - "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, - "request_cls": JWTSecuredAuthorizationRequest, - }, + "kwargs": {}, } }, "authentication": { diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index 019349b0..6a2d7912 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -4,14 +4,14 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses -import yaml from cryptojwt import JWT from cryptojwt import KeyJar from cryptojwt.jws.jws import factory from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e +import pytest +import responses +import yaml from idpyoidc.exception import ParameterError from idpyoidc.exception import URIError @@ -863,7 +863,15 @@ def test_parse_request(self): "scope": AUTH_REQ.get("scope"), } ) - assert "__verified_request" in _req + assert set(_req.keys()) == {'__verified_request', + 'aud', + 'client_id', + 'iat', + 'iss', + 'redirect_uri', + 'response_type', + 'scope', + 'state'} def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") @@ -1464,7 +1472,8 @@ def test_authenticated_as_with_goobledigook(self): ) kakor = [{ - 'value': '{"sub": "adam", "sid": "Z0FBQUFBQmlhVl", "state": "state_identifier", "client_id": "client 12345"}', + 'value': '{"sub": "adam", "sid": "Z0FBQUFBQmlhVl", "state": "state_identifier", ' + '"client_id": "client 12345"}', 'type': '', 'timestamp': '1651070251'}] diff --git a/tests/test_server_26_oidc_userinfo_endpoint.py b/tests/test_server_26_oidc_userinfo_endpoint.py index 6bfa071e..16b50117 100755 --- a/tests/test_server_26_oidc_userinfo_endpoint.py +++ b/tests/test_server_26_oidc_userinfo_endpoint.py @@ -42,13 +42,6 @@ ] CAPABILITIES = { - "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], } AUTH_REQ = AuthorizationRequest( @@ -85,7 +78,40 @@ def create_endpoint(self): conf = { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, - "capabilities": CAPABILITIES, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + 'claims_supported': [ + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], "cookie_handler": { "class": CookieHandler, "kwargs": { @@ -130,11 +156,6 @@ def create_endpoint(self): "path": "userinfo", "class": userinfo.UserInfo, "kwargs": { - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], "client_authn_method": ["bearer_header", "bearer_body"], }, }, @@ -236,28 +257,28 @@ def test_init(self): assert set( self.endpoint.server_get("endpoint_context").provider_info["claims_supported"] ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } def test_parse(self): session_id = self._create_session(AUTH_REQ) diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 22a5bb51..52570e68 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -291,33 +291,11 @@ def _dump_restore(self, fro, to): def test_init(self): assert self.endpoint[1] assert set( - self.endpoint[1].server_get("endpoint_context").provider_info["claims_supported"] - ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] + ) == {"openid"} assert set( - self.endpoint[1].server_get("endpoint_context").provider_info["claims_supported"] - ) == set(self.endpoint[2].server_get("endpoint_context").provider_info["claims_supported"]) + self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] + ) == set(self.endpoint[2].server_get("endpoint_context").provider_info["scopes_supported"]) def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) diff --git a/tests/test_tandem_10_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py similarity index 97% rename from tests/test_tandem_10_token_exchange.py rename to tests/test_tandem_10_oauth2_token_exchange.py index d3e50ea8..b86b6da0 100644 --- a/tests/test_tandem_10_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -43,16 +43,6 @@ ["none"], ] -CAPABILITIES = { - "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], -} - AUTH_REQ = AuthorizationRequest( client_id="client_1", redirect_uri="https://example.com/cb", @@ -103,7 +93,19 @@ def create_endpoint(self): server_conf = { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, - "capabilities": CAPABILITIES, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ], "cookie_handler": { "class": CookieHandler, "kwargs": {"keys": {"key_defs": COOKIE_KEYDEFS}}, @@ -123,14 +125,7 @@ def create_endpoint(self): "token": { "path": "token", "class": "idpyoidc.server.oidc.token.Token", - "kwargs": { - "client_authn_method": [ - "client_secret_basic", - "client_secret_post", - "client_secret_jwt", - "private_key_jwt", - ], - }, + "kwargs": {}, }, }, "authentication": { @@ -298,7 +293,7 @@ def process_setup(self, token=None, scope=None): "redirect_uri": areq["redirect_uri"], "grant_type": "authorization_code", "client_id": self.client_1.get_client_id(), - "client_secret": _context.get("client_secret"), + "client_secret": _context.get_usage("client_secret"), } _token_request, resp = self.do_query("accesstoken", 'token', req_args, _state) From abc35c59d407e56ddec48be45f9901142246c8c4 Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 11 Nov 2022 09:09:39 +0100 Subject: [PATCH 26/76] Spring(?)/Autumn cleaning. --- src/idpyoidc/client/specification/__init__.py | 0 src/idpyoidc/client/specification/oidc.py | 0 tests/static/jwks.json | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 src/idpyoidc/client/specification/__init__.py create mode 100644 src/idpyoidc/client/specification/oidc.py diff --git a/src/idpyoidc/client/specification/__init__.py b/src/idpyoidc/client/specification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/idpyoidc/client/specification/oidc.py b/src/idpyoidc/client/specification/oidc.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file From 290fb46553b2f07f88fec1c4ac0eadf135a0c324 Mon Sep 17 00:00:00 2001 From: roland Date: Sat, 12 Nov 2022 08:28:03 +0100 Subject: [PATCH 27/76] Merged --- .../client/{specification => work_condition}/__init__.py | 0 .../client/{specification/oidc.py => work_condition/oauth2.py} | 0 src/idpyoidc/client/work_condition/oidc.py | 0 src/idpyoidc/server/oauth2/token.py | 3 --- 4 files changed, 3 deletions(-) rename src/idpyoidc/client/{specification => work_condition}/__init__.py (100%) rename src/idpyoidc/client/{specification/oidc.py => work_condition/oauth2.py} (100%) create mode 100644 src/idpyoidc/client/work_condition/oidc.py diff --git a/src/idpyoidc/client/specification/__init__.py b/src/idpyoidc/client/work_condition/__init__.py similarity index 100% rename from src/idpyoidc/client/specification/__init__.py rename to src/idpyoidc/client/work_condition/__init__.py diff --git a/src/idpyoidc/client/specification/oidc.py b/src/idpyoidc/client/work_condition/oauth2.py similarity index 100% rename from src/idpyoidc/client/specification/oidc.py rename to src/idpyoidc/client/work_condition/oauth2.py diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py new file mode 100644 index 00000000..e69de29b diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index e0a77196..431ae7ad 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -7,16 +7,13 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.message.oidc import TokenErrorResponse -from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError from idpyoidc.server.oauth2.token_helper import AccessTokenHelper from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper from idpyoidc.server.session import MintingNotAllowed -from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING from idpyoidc.util import importer logger = logging.getLogger(__name__) From 04cd0f10f5a02e338a9ea268f0720763ae759bc3 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 14 Nov 2022 08:40:12 +0100 Subject: [PATCH 28/76] Reworking the work condition system. This is about going from what the software can do and what the admin wants it to do to what is actually used. --- tests/xtest_x_ciba_01_backchannel_auth.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/xtest_x_ciba_01_backchannel_auth.py diff --git a/tests/xtest_x_ciba_01_backchannel_auth.py b/tests/xtest_x_ciba_01_backchannel_auth.py new file mode 100644 index 00000000..e69de29b From d8fb841058c5b6e57fc7e52129ea734b80276d4d Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 17 Nov 2022 08:42:46 +0100 Subject: [PATCH 29/76] Refactoring and putting better names on things. --- .../client/work_condition/transform.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 src/idpyoidc/client/work_condition/transform.py diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py new file mode 100644 index 00000000..6d0c8220 --- /dev/null +++ b/src/idpyoidc/client/work_condition/transform.py @@ -0,0 +1,115 @@ +import logging +from typing import Optional + +from idpyoidc.message.oidc import RegistrationResponse + +logger = logging.getLogger(__name__) + +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + "scope": "scopes_supported", + "display": "display_values_supported", + "claims": "claims_supported", + "request": "request_parameter_supported", + "request_uri": "request_uri_parameter_supported", + 'claims_locales': 'claims_locales_supported', + 'ui_locales': 'ui_locales_supported', +} + +PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) + +REQUEST2REGISTER = { + 'client_id': "client_id", + "client_secret": "client_secret", + # 'acr_values': "default_acr_values" , + # 'max_age': "default_max_age", + 'redirect_uri': "redirect_uris", + 'response_type': "response_types", + 'request_uri': "request_uris", + 'grant_type': "grant_types" +} + + +# AUTHORIZATION_REQUEST = [ +# "acr_values", +# "claims", +# "claims_locales", +# "client_id", +# "display", +# "id_token_hint", +# "login_hint", +# "max_age", +# "nonce", +# "prompt", +# "redirect_uri", +# "registration", +# "request", +# "request_uri", +# "response_mode" +# "response_type", +# "scope", +# "state", +# "ui_locales", +# ] + + +def supported_to_preferred(supported: dict, preference: dict, info: Optional[dict] = None): + for key, val in supported.items(): + if info and key in info: + preference[key] = info[key] + continue + + if val is None: + continue + + if key not in preference: + preference[key] = val + + return preference + + +def preferred_to_register(prefers: dict, use: Optional[dict] = None): + if not use: + use = {} + + for key, spec in RegistrationResponse.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + + _preferred_values = prefers.get(_pref_key) + if not _preferred_values: + continue + + if isinstance(spec[0], list): + if _preferred_values: + use[key] = _preferred_values + else: + if _preferred_values: + if isinstance(_preferred_values, list): + use[key] = _preferred_values[0] + else: + use[key] = _preferred_values + + _rr_keys = list(RegistrationResponse.c_param.keys()) + for key, val in prefers.items(): + if PREFERRED2REGISTER.get(key): + continue + if key not in _rr_keys: + use[key] = val + + logger.debug(f"Entity uses: {use}") + return use From 4b4e2aa062edc0422d08b252a7e82296bc9533bf Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 18 Nov 2022 08:50:03 +0100 Subject: [PATCH 30/76] Fixed tests up to client_28 --- .../client/work_condition/transform.py | 57 +++++++++++++++---- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index 6d0c8220..ad4d68a6 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -68,17 +68,51 @@ # ] -def supported_to_preferred(supported: dict, preference: dict, info: Optional[dict] = None): - for key, val in supported.items(): - if info and key in info: - preference[key] = info[key] - continue - - if val is None: - continue - - if key not in preference: - preference[key] = val +def supported_to_preferred(supported: dict, + preference: dict, + base_url: str, + info: Optional[dict] = None, + ): + if info: # The provider info + for key, val in supported.items(): + if key in preference: + _pref_val = preference.get(key) # defined in configuration + _info_val = info.get(key) + if _info_val: + # Only use provider setting if less or equal to what I support + if key.endswith('supported'): # list + preference[key] = [x for x in _pref_val if x in _info_val] + else: + pass + elif val is None: # No default + # if key not in ['jwks_uri', 'jwks']: + pass + else: + # there is a default + _info_val = info.get(key) + if _info_val: # The OP has an opinion + if key.endswith('supported'): # list + preference[key] = [x for x in val if x in _info_val] + else: + pass + else: + preference[key] = val + + # special case -> must have a request_uris value + if 'require_request_uri_registration' in info: + # only makes sense if I want to use request_uri + if preference.get('request_parameter') == 'request_uri': + if 'request_uri' not in preference: + preference['request_uris'] = [f'{base_url}/requests'] + else: # just ignore + logger.info('Asked for "request_uri" which it did not plan to use') + else: + # Add defaults + for key, val in supported.items(): + if val is None: + continue + if key not in preference: + preference[key] = val return preference @@ -104,6 +138,7 @@ def preferred_to_register(prefers: dict, use: Optional[dict] = None): else: use[key] = _preferred_values + # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) for key, val in prefers.items(): if PREFERRED2REGISTER.get(key): From eb0c8963732da412b375046ef16e1cfd09140bc1 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 18 Nov 2022 14:54:01 +0100 Subject: [PATCH 31/76] working on tests. --- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file From 235b9b3e4f3e89e1649c22ec46f23a398e1e3118 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 22 Nov 2022 10:30:27 +0100 Subject: [PATCH 32/76] working on tests... --- .../client/work_condition/transform.py | 90 +++++++++---------- tests/pub_client.jwks | 2 +- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index ad4d68a6..c35a5252 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -23,12 +23,12 @@ "response_types": "response_types_supported", "grant_types": "grant_types_supported", "scope": "scopes_supported", - "display": "display_values_supported", - "claims": "claims_supported", - "request": "request_parameter_supported", - "request_uri": "request_uri_parameter_supported", - 'claims_locales': 'claims_locales_supported', - 'ui_locales': 'ui_locales_supported', + # "display": "display_values_supported", + # "claims": "claims_supported", + # "request": "request_parameter_supported", + # "request_uri": "request_uri_parameter_supported", + # 'claims_locales': 'claims_locales_supported', + # 'ui_locales': 'ui_locales_supported', } PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) @@ -41,33 +41,11 @@ 'redirect_uri': "redirect_uris", 'response_type': "response_types", 'request_uri': "request_uris", - 'grant_type': "grant_types" + 'grant_type': "grant_types", + "scope": 'scopes_supported', } -# AUTHORIZATION_REQUEST = [ -# "acr_values", -# "claims", -# "claims_locales", -# "client_id", -# "display", -# "id_token_hint", -# "login_hint", -# "max_age", -# "nonce", -# "prompt", -# "redirect_uri", -# "registration", -# "request", -# "request_uri", -# "response_mode" -# "response_type", -# "scope", -# "state", -# "ui_locales", -# ] - - def supported_to_preferred(supported: dict, preference: dict, base_url: str, @@ -84,7 +62,7 @@ def supported_to_preferred(supported: dict, preference[key] = [x for x in _pref_val if x in _info_val] else: pass - elif val is None: # No default + elif val is None: # No default, means the RP does not have a preference # if key not in ['jwks_uri', 'jwks']: pass else: @@ -117,26 +95,40 @@ def supported_to_preferred(supported: dict, return preference -def preferred_to_register(prefers: dict, use: Optional[dict] = None): - if not use: - use = {} +def array_to_singleton(claim_spec, values): + if isinstance(claim_spec[0], list): + return values + else: + if isinstance(values, list): + return values[0] + else: # singleton + return values + + +def preferred_to_registered(prefers: dict, registration_response: Optional[dict] = None): + """ + The claims with values that are returned from the OP is what goes unless (!!) + the values returned are not within the supported values. + + @param prefers: + @param registration_response: + @return: + """ + registered = {} + + if registration_response: + for key, val in registration_response.items(): + registered[key] = val # Should I just accept with the OP says ?? for key, spec in RegistrationResponse.c_param.items(): + if key in registered: + continue _pref_key = REGISTER2PREFERRED.get(key, key) _preferred_values = prefers.get(_pref_key) if not _preferred_values: continue - - if isinstance(spec[0], list): - if _preferred_values: - use[key] = _preferred_values - else: - if _preferred_values: - if isinstance(_preferred_values, list): - use[key] = _preferred_values[0] - else: - use[key] = _preferred_values + registered[key] = array_to_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) @@ -144,7 +136,11 @@ def preferred_to_register(prefers: dict, use: Optional[dict] = None): if PREFERRED2REGISTER.get(key): continue if key not in _rr_keys: - use[key] = val + registered[key] = val + + logger.debug(f"Entity registered: {registered}") + return registered + - logger.debug(f"Entity uses: {use}") - return use +def register_to_request(prefers, registration_response): + pass \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file From f906c85b36a369e96bc8c6f5d27c02764ceb0e49 Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 22 Nov 2022 17:24:56 +0100 Subject: [PATCH 33/76] Fixed tests up to client_28. Back and forth... --- .../client/work_condition/transform.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index c35a5252..83f5e47e 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse logger = logging.getLogger(__name__) @@ -95,9 +96,12 @@ def supported_to_preferred(supported: dict, return preference -def array_to_singleton(claim_spec, values): +def array_or_singleton(claim_spec, values): if isinstance(claim_spec[0], list): - return values + if isinstance(values, list): + return values + else: + return [values] else: if isinstance(values, list): return values[0] @@ -128,7 +132,7 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] _preferred_values = prefers.get(_pref_key) if not _preferred_values: continue - registered[key] = array_to_singleton(spec, _preferred_values) + registered[key] = array_or_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) @@ -142,5 +146,16 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] return registered -def register_to_request(prefers, registration_response): - pass \ No newline at end of file +def create_registration_request(prefers, supported): + _request = {} + for key, spec in RegistrationRequest.c_param.items(): + _pref_key = REGISTER2PREFERRED.get(key, key) + if _pref_key in prefers: + value = prefers[_pref_key] + elif _pref_key in supported: + value = supported[_pref_key] + else: + continue + + _request[key] = array_or_singleton(spec, value) + return _request \ No newline at end of file From f5c2eb6d978cf0cc8e0c07ccf8b228e23654fc5f Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 26 Nov 2022 08:26:23 +0100 Subject: [PATCH 34/76] All tests green --- .../client/work_condition/transform.py | 42 +++++++++++++++---- tests/pub_client.jwks | 2 +- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py index 83f5e47e..c2fcfe69 100644 --- a/src/idpyoidc/client/work_condition/transform.py +++ b/src/idpyoidc/client/work_condition/transform.py @@ -20,10 +20,11 @@ "default_acr_values": "acr_values_supported", "subject_type": "subject_types_supported", "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", "response_types": "response_types_supported", "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", # "display": "display_values_supported", # "claims": "claims_supported", # "request": "request_parameter_supported", @@ -44,6 +45,7 @@ 'request_uri': "request_uris", 'grant_type': "grant_types", "scope": 'scopes_supported', + 'post_logout_redirect_uri': "post_logout_redirect_uris" } @@ -109,7 +111,18 @@ def array_or_singleton(claim_spec, values): return values -def preferred_to_registered(prefers: dict, registration_response: Optional[dict] = None): +def _is_subset(a, b): + if isinstance(a, list): + if isinstance(b, list): + return set(b).issubset(set(a)) + elif isinstance(b, list): + return a in b + else: + return a == b + + +def preferred_to_registered(prefers: dict, supported: dict, + registration_response: Optional[dict] = None): """ The claims with values that are returned from the OP is what goes unless (!!) the values returned are not within the supported values. @@ -122,25 +135,33 @@ def preferred_to_registered(prefers: dict, registration_response: Optional[dict] if registration_response: for key, val in registration_response.items(): - registered[key] = val # Should I just accept with the OP says ?? + if key in REGISTER2PREFERRED: + if _is_subset(val, supported.get(REGISTER2PREFERRED[key])): + registered[key] = val + else: + logger.warning(f'OP tells me to do something I do not support: {key} = {val}') + else: + registered[key] = val # Should I just accept with the OP says ?? for key, spec in RegistrationResponse.c_param.items(): if key in registered: continue _pref_key = REGISTER2PREFERRED.get(key, key) - _preferred_values = prefers.get(_pref_key) + _preferred_values = prefers.get(_pref_key, prefers.get(key)) if not _preferred_values: continue + registered[key] = array_or_singleton(spec, _preferred_values) # transfer those claims that are not part of the registration request _rr_keys = list(RegistrationResponse.c_param.keys()) for key, val in prefers.items(): - if PREFERRED2REGISTER.get(key): - continue - if key not in _rr_keys: - registered[key] = val + _reg_key = PREFERRED2REGISTER.get(key, key) + if _reg_key not in _rr_keys: + # If they are not part of the registration request I do not knoe if it is supposed to + # be a singleton or an array. So just add it as is. + registered[_reg_key] = val logger.debug(f"Entity registered: {registered}") return registered @@ -157,5 +178,8 @@ def create_registration_request(prefers, supported): else: continue + if not value: + continue + _request[key] = array_or_singleton(spec, value) - return _request \ No newline at end of file + return _request diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file From 8879ac96c70ffeaa0fa28e74827f075a92cf3d61 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 27 Nov 2022 09:28:53 +0100 Subject: [PATCH 35/76] Cleaned up code and removed keyjar from work_condition. --- tests/test_client_21_oidc_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 53a4010c..cb8c5386 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -930,6 +930,7 @@ def test_config_with_required_request_uri(): 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} + def test_config_logout_uri(): client_config = { "client_id": "client_id", From bf7d3eeefeae9ceb5f4da875d91db2a362fdcdf6 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 29 Nov 2022 20:18:57 +0100 Subject: [PATCH 36/76] Replaced StateInterface with Current a much simpler state manager. --- .../__init__.py => state_interface.py} | 0 src/idpyoidc/client/work_condition/oauth2.py | 0 src/idpyoidc/client/work_condition/oidc.py | 0 .../client/work_condition/transform.py | 185 ------------------ tests/test_client_21_oidc_service.py | 1 - 5 files changed, 186 deletions(-) rename src/idpyoidc/client/{work_condition/__init__.py => state_interface.py} (100%) delete mode 100644 src/idpyoidc/client/work_condition/oauth2.py delete mode 100644 src/idpyoidc/client/work_condition/oidc.py delete mode 100644 src/idpyoidc/client/work_condition/transform.py diff --git a/src/idpyoidc/client/work_condition/__init__.py b/src/idpyoidc/client/state_interface.py similarity index 100% rename from src/idpyoidc/client/work_condition/__init__.py rename to src/idpyoidc/client/state_interface.py diff --git a/src/idpyoidc/client/work_condition/oauth2.py b/src/idpyoidc/client/work_condition/oauth2.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/idpyoidc/client/work_condition/oidc.py b/src/idpyoidc/client/work_condition/oidc.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/idpyoidc/client/work_condition/transform.py b/src/idpyoidc/client/work_condition/transform.py deleted file mode 100644 index c2fcfe69..00000000 --- a/src/idpyoidc/client/work_condition/transform.py +++ /dev/null @@ -1,185 +0,0 @@ -import logging -from typing import Optional - -from idpyoidc.message.oidc import RegistrationRequest -from idpyoidc.message.oidc import RegistrationResponse - -logger = logging.getLogger(__name__) - -REGISTER2PREFERRED = { - # "require_signed_request_object": "request_object_algs_supported", - "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": "request_object_encryption_enc_values_supported", - "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", - "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", - "default_acr_values": "acr_values_supported", - "subject_type": "subject_types_supported", - "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "response_types": "response_types_supported", - "grant_types": "grant_types_supported", - # In OAuth2 but not in OIDC - "scope": "scopes_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", - # "display": "display_values_supported", - # "claims": "claims_supported", - # "request": "request_parameter_supported", - # "request_uri": "request_uri_parameter_supported", - # 'claims_locales': 'claims_locales_supported', - # 'ui_locales': 'ui_locales_supported', -} - -PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) - -REQUEST2REGISTER = { - 'client_id': "client_id", - "client_secret": "client_secret", - # 'acr_values': "default_acr_values" , - # 'max_age': "default_max_age", - 'redirect_uri': "redirect_uris", - 'response_type': "response_types", - 'request_uri': "request_uris", - 'grant_type': "grant_types", - "scope": 'scopes_supported', - 'post_logout_redirect_uri': "post_logout_redirect_uris" -} - - -def supported_to_preferred(supported: dict, - preference: dict, - base_url: str, - info: Optional[dict] = None, - ): - if info: # The provider info - for key, val in supported.items(): - if key in preference: - _pref_val = preference.get(key) # defined in configuration - _info_val = info.get(key) - if _info_val: - # Only use provider setting if less or equal to what I support - if key.endswith('supported'): # list - preference[key] = [x for x in _pref_val if x in _info_val] - else: - pass - elif val is None: # No default, means the RP does not have a preference - # if key not in ['jwks_uri', 'jwks']: - pass - else: - # there is a default - _info_val = info.get(key) - if _info_val: # The OP has an opinion - if key.endswith('supported'): # list - preference[key] = [x for x in val if x in _info_val] - else: - pass - else: - preference[key] = val - - # special case -> must have a request_uris value - if 'require_request_uri_registration' in info: - # only makes sense if I want to use request_uri - if preference.get('request_parameter') == 'request_uri': - if 'request_uri' not in preference: - preference['request_uris'] = [f'{base_url}/requests'] - else: # just ignore - logger.info('Asked for "request_uri" which it did not plan to use') - else: - # Add defaults - for key, val in supported.items(): - if val is None: - continue - if key not in preference: - preference[key] = val - - return preference - - -def array_or_singleton(claim_spec, values): - if isinstance(claim_spec[0], list): - if isinstance(values, list): - return values - else: - return [values] - else: - if isinstance(values, list): - return values[0] - else: # singleton - return values - - -def _is_subset(a, b): - if isinstance(a, list): - if isinstance(b, list): - return set(b).issubset(set(a)) - elif isinstance(b, list): - return a in b - else: - return a == b - - -def preferred_to_registered(prefers: dict, supported: dict, - registration_response: Optional[dict] = None): - """ - The claims with values that are returned from the OP is what goes unless (!!) - the values returned are not within the supported values. - - @param prefers: - @param registration_response: - @return: - """ - registered = {} - - if registration_response: - for key, val in registration_response.items(): - if key in REGISTER2PREFERRED: - if _is_subset(val, supported.get(REGISTER2PREFERRED[key])): - registered[key] = val - else: - logger.warning(f'OP tells me to do something I do not support: {key} = {val}') - else: - registered[key] = val # Should I just accept with the OP says ?? - - for key, spec in RegistrationResponse.c_param.items(): - if key in registered: - continue - _pref_key = REGISTER2PREFERRED.get(key, key) - - _preferred_values = prefers.get(_pref_key, prefers.get(key)) - if not _preferred_values: - continue - - registered[key] = array_or_singleton(spec, _preferred_values) - - # transfer those claims that are not part of the registration request - _rr_keys = list(RegistrationResponse.c_param.keys()) - for key, val in prefers.items(): - _reg_key = PREFERRED2REGISTER.get(key, key) - if _reg_key not in _rr_keys: - # If they are not part of the registration request I do not knoe if it is supposed to - # be a singleton or an array. So just add it as is. - registered[_reg_key] = val - - logger.debug(f"Entity registered: {registered}") - return registered - - -def create_registration_request(prefers, supported): - _request = {} - for key, spec in RegistrationRequest.c_param.items(): - _pref_key = REGISTER2PREFERRED.get(key, key) - if _pref_key in prefers: - value = prefers[_pref_key] - elif _pref_key in supported: - value = supported[_pref_key] - else: - continue - - if not value: - continue - - _request[key] = array_or_singleton(spec, value) - return _request diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index cb8c5386..53a4010c 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -930,7 +930,6 @@ def test_config_with_required_request_uri(): 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} - def test_config_logout_uri(): client_config = { "client_id": "client_id", From cfbeed2ed8d5648047e79de069b69c3bd9971ddb Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 1 Dec 2022 10:30:42 +0100 Subject: [PATCH 37/76] Make server side also use WorkEnvironment --- src/idpyoidc/client/work_environment/oauth2.py | 2 +- tests/static/jwks.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 71fedde9..69b5a20c 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,6 +1,6 @@ from typing import Optional -from idpyoidc.client import work_environment +from idpyoidc import work_environment class WorkEnvironment(work_environment.WorkEnvironment): diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file From a6383bad1623d5156aa165b2d87c9e9130e787f1 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 1 Dec 2022 19:10:43 +0100 Subject: [PATCH 38/76] Partly done with harmonizing work environment usage. --- tests/static/jwks.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file From 27613ff2641b8ff3e0456f93561c6104acb683b0 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 3 Dec 2022 10:40:14 +0100 Subject: [PATCH 39/76] All tests green. --- src/idpyoidc/client/work_environment/oauth2.py | 2 +- src/idpyoidc/server/oauth2/token.py | 3 +++ tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/static/jwks.json | 2 +- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 69b5a20c..71fedde9 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,6 +1,6 @@ from typing import Optional -from idpyoidc import work_environment +from idpyoidc.client import work_environment class WorkEnvironment(work_environment.WorkEnvironment): diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 431ae7ad..e0a77196 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -7,13 +7,16 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.message.oidc import TokenErrorResponse +from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError from idpyoidc.server.oauth2.token_helper import AccessTokenHelper from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper from idpyoidc.server.session import MintingNotAllowed +from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING from idpyoidc.util import importer logger = logging.getLogger(__name__) diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 77081f40..9b062907 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file From c9043896d157b8fde94a8fde9596c17d8f8560eb Mon Sep 17 00:00:00 2001 From: roland Date: Sat, 8 Oct 2022 08:47:01 +0200 Subject: [PATCH 40/76] Fedservice support --- src/idpyoidc/__init__.py | 3 + src/idpyoidc/actor/__init__.py | 67 +++- src/idpyoidc/client/client_auth.py | 12 +- src/idpyoidc/client/entity.py | 30 +- src/idpyoidc/client/oauth2/__init__.py | 5 +- src/idpyoidc/client/oauth2/access_token.py | 8 +- src/idpyoidc/client/oauth2/add_on/dpop.py | 2 +- .../oauth2/add_on/identity_assurance.py | 2 +- src/idpyoidc/client/oauth2/add_on/pkce.py | 6 +- .../oauth2/add_on/pushed_authorization.py | 4 +- src/idpyoidc/client/oauth2/authorization.py | 13 +- .../client_credentials/cc_access_token.py | 8 +- .../cc_refresh_access_token.py | 10 +- .../client/oauth2/refresh_access_token.py | 8 +- src/idpyoidc/client/oauth2/server_metadata.py | 12 +- src/idpyoidc/client/oauth2/utils.py | 4 +- src/idpyoidc/client/oidc/__init__.py | 33 +- src/idpyoidc/client/oidc/access_token.py | 14 +- src/idpyoidc/client/oidc/authorization.py | 27 +- .../client/oidc/backchannel_authentication.py | 8 +- src/idpyoidc/client/oidc/check_id.py | 6 +- src/idpyoidc/client/oidc/check_session.py | 6 +- .../client/oidc/provider_info_discovery.py | 14 +- src/idpyoidc/client/oidc/read_registration.py | 4 +- .../client/oidc/refresh_access_token.py | 2 +- src/idpyoidc/client/oidc/registration.py | 4 +- src/idpyoidc/client/oidc/userinfo.py | 10 +- src/idpyoidc/client/oidc/webfinger.py | 8 +- src/idpyoidc/client/rp_handler.py | 43 ++- src/idpyoidc/client/service.py | 31 +- src/idpyoidc/context.py | 6 + src/idpyoidc/server/__init__.py | 20 +- src/idpyoidc/server/authz/__init__.py | 8 +- src/idpyoidc/server/client_authn.py | 2 +- src/idpyoidc/server/configure.py | 2 + src/idpyoidc/server/endpoint.py | 16 +- src/idpyoidc/server/endpoint_context.py | 8 +- src/idpyoidc/server/oauth2/add_on/dpop.py | 6 +- .../server/oauth2/add_on/extra_args.py | 2 +- src/idpyoidc/server/oauth2/authorization.py | 18 +- src/idpyoidc/server/oauth2/introspection.py | 4 +- .../server/oauth2/pushed_authorization.py | 2 +- src/idpyoidc/server/oauth2/token.py | 4 +- src/idpyoidc/server/oauth2/token_helper.py | 16 +- .../server/oidc/add_on/custom_scopes.py | 2 +- src/idpyoidc/server/oidc/add_on/pkce.py | 2 +- src/idpyoidc/server/oidc/authorization.py | 2 +- .../server/oidc/backchannel_authentication.py | 12 +- src/idpyoidc/server/oidc/discovery.py | 2 +- src/idpyoidc/server/oidc/provider_config.py | 2 +- src/idpyoidc/server/oidc/read_registration.py | 4 +- src/idpyoidc/server/oidc/registration.py | 30 +- src/idpyoidc/server/oidc/session.py | 24 +- src/idpyoidc/server/oidc/token_helper.py | 8 +- src/idpyoidc/server/oidc/userinfo.py | 6 +- src/idpyoidc/server/scopes.py | 4 +- src/idpyoidc/server/session/claims.py | 10 +- src/idpyoidc/server/token/handler.py | 2 +- src/idpyoidc/server/token/id_token.py | 8 +- src/idpyoidc/server/token/jwt_token.py | 6 +- src/idpyoidc/server/user_authn/user.py | 6 +- tests/test_server_17_client_authn.py | 6 +- tests/test_server_20d_client_authn.py | 4 +- ...server_24_oauth2_authorization_endpoint.py | 44 +-- ...er_24_oauth2_authorization_endpoint_jar.py | 4 +- ...t_server_24_oidc_authorization_endpoint.py | 36 +- tests/test_server_30_oidc_end_session.py | 50 +-- tests/test_server_31_oauth2_introspection.py | 20 +- tests/test_server_33_oauth2_pkce.py | 8 +- tests/test_server_36_oauth2_token_exchange.py | 2 +- tests/test_server_50_persistence.py | 4 +- tests/test_server_61_add_on.py | 2 +- tests/test_y_actor_01.py | 351 ++++++++++++++++++ 73 files changed, 814 insertions(+), 365 deletions(-) create mode 100644 tests/test_y_actor_01.py diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 691c3c91..5b03c94b 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,6 +1,9 @@ __author__ = "Roland Hedberg" __version__ = "1.4.0" +import os +from typing import Dict + VERIFIED_CLAIM_PREFIX = "__verified" diff --git a/src/idpyoidc/actor/__init__.py b/src/idpyoidc/actor/__init__.py index 792d6005..cd62398a 100644 --- a/src/idpyoidc/actor/__init__.py +++ b/src/idpyoidc/actor/__init__.py @@ -1 +1,66 @@ -# +from typing import Optional +from uuid import uuid4 + +from cryptojwt.key_jar import KeyJar + +from idpyoidc.impexp import ImpExp + + +class CIBAClient(ImpExp): + parameter = {"context": {}} + + def __init__( + self, + keyjar: Optional[KeyJar] = None, + ): + ImpExp.__init__(self) + self.keyjar = keyjar + self.server = None + self.client = None + self.context = {} + + def create_authentication_request(self, scope, binding_message, login_hint): + _service = self.client.superior_get("service", "backchannel_authentication") + + client_notification_token = uuid4().hex + + request_args = { + "scope": scope, + "client_notification_token": client_notification_token, + "binding_message": binding_message, + "login_hint": login_hint, + } + request = _service.get_request_parameters( + request_args=request_args, authn_method="private_key_jwt" + ) + + self.context[client_notification_token] = { + "authentication_request": request, + "client_id": _service.superior_get("context").issuer, + } + return request + + def get_client_id_from_token(self, token): + _context = self.context[token] + return _context["client_id"] + + def do_client_notification(self, msg, http_info): + _notification_endpoint = self.server.server_get("endpoint", "client_notification") + _nreq = _notification_endpoint.parse_request( + msg, http_info, get_client_id_from_token=self.get_client_id_from_token + ) + _ninfo = _notification_endpoint.process_request(_nreq) + + +class CIBAServer(ImpExp): + parameter = {"context": {}} + + def __init__( + self, + keyjar: Optional[KeyJar] = None, + ): + ImpExp.__init__(self) + self.keyjar = keyjar + self.server = None + self.client = None + self.context = {} diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index a1d8f672..b04cf595 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -95,7 +95,7 @@ def _get_passwd(request, service, **kwargs): try: passwd = request["client_secret"] except KeyError: - passwd = service.client_get("service_context").get_usage('client_secret') + passwd = service.superior_get("context").get_usage('client_secret') return passwd @staticmethod @@ -103,7 +103,7 @@ def _get_user(service, **kwargs): try: user = kwargs["user"] except KeyError: - user = service.client_get("service_context").get_client_id() + user = service.superior_get("context").get_client_id() return user def _get_authentication_token(self, request, service, **kwargs): @@ -138,7 +138,7 @@ def _with_or_without_client_id(request, service): ): if "client_id" not in request: try: - request["client_id"] = service.client_get("service_context").get_client_id() + request["client_id"] = service.superior_get("context").get_client_id() except AttributeError: pass else: @@ -215,7 +215,7 @@ def modify_request(self, request, service, **kwargs): :param request: The request :param service: The service that is using this authentication method """ - _context = service.client_get("service_context") + _context = service.superior_get("context") if "client_secret" not in request: try: request["client_secret"] = kwargs["client_secret"] @@ -272,7 +272,7 @@ def find_token(request, token_type, service, **kwargs): except KeyError: # Get the latest acquired access token. _state = kwargs.get("state", kwargs.get("key")) - _arg = service.client_get("service_context").cstate.get_set(_state, claim=[token_type]) + _arg = service.superior_get("context").cstate.get_set(_state, claim=[token_type]) return _arg.get("access_token") @@ -482,7 +482,7 @@ def _get_audience_and_algorithm(self, context, **kwargs): return audience, algorithm def _construct_client_assertion(self, service, **kwargs): - _context = service.client_get("service_context") + _context = service.superior_get("context") audience, algorithm = self._get_audience_and_algorithm(_context, **kwargs) if "kid" in kwargs: diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 90a95de3..9849d785 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -78,7 +78,9 @@ def __init__( config: Optional[Union[dict, Configuration]] = None, services: Optional[dict] = None, httpc_params: Optional[dict] = None, - client_type: Optional[str] = "oauth2" + client_type: Optional[str] = "oauth2", + context: Optional[OidcContext] = None, + superior_get: Optional[Callable] = None ): self.extra = {} if httpc_params: @@ -101,7 +103,7 @@ def __init__( else: _srvs = DEFAULT_OIDC_SERVICES - self._service = init_services(service_definitions=_srvs, client_get=self.client_get) + self._service = init_services(service_definitions=_srvs, superior_get=self.entity_get) self._service_context = ServiceContext( keyjar=keyjar, config=config, httpc_params=self.httpc_params, @@ -111,12 +113,12 @@ def __init__( self.keyjar = self._service_context.get_preference('keyjar') self.setup_client_authn_methods(config) + self.superior_get = superior_get # Deal with backward compatibility self.backward_compatibility(config) - - def client_get(self, what, *arg): + def entity_get(self, what, *arg): _func = getattr(self, "get_{}".format(what), None) if _func: return _func(*arg) @@ -125,7 +127,10 @@ def client_get(self, what, *arg): def get_services(self, *arg): return self._service - def get_service_context(self, *arg): + def get_service_context(self, *arg): # Want to get rid of this + return self._service_context + + def get_context(self, *arg): return self._service_context def get_service(self, service_name, *arg): @@ -151,10 +156,19 @@ def get_client_id(self): else: return self._service_context.work_environment.get_preference('client_id') + def get_keyjar(self): + if self.get_service_context().keyjar: + return self.get_service_context().keyjar + else: + return self.superior_get('application', 'keyjar') + def setup_client_authn_methods(self, config): - self._service_context.client_authn_method = client_auth_setup( - config.get("client_authn_methods") - ) + if config and "client_authn_methods" in config: + self._service_context.client_authn_method = client_auth_setup( + config.get("client_authn_methods") + ) + else: + self._service_context.client_authn_method = {} def backward_compatibility(self, config): _work_environment = self._service_context.work_environment diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index c464d4a2..6287abc0 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -41,6 +41,7 @@ def __init__( services=None, httpc_params=None, client_type: Optional[str] = "" + **kwargs ): """ @@ -140,7 +141,7 @@ def get_response( if resp.status_code < 300: if "keyjar" not in kwargs: - kwargs["keyjar"] = service.client_get("service_context").keyjar + kwargs["keyjar"] = service.superior_get("context").keyjar if not response_body_type: response_body_type = service.response_body_type @@ -293,7 +294,7 @@ def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dic except KeyError: raise ConfigurationError("Can not do dynamic provider info discovery") else: - _context = client.client_get("service_context") + _context = client.superior_get("context") try: _context.set("issuer", _context.config["srv_discovery_url"]) except KeyError: diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index df0804d0..1ccb61e0 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -32,15 +32,15 @@ class AccessToken(Service): "token_endpoint_auth_signing_alg": get_signing_algs, } - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) def update_service_context(self, resp, key: Optional[str] = '', **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) if key: - self.client_get("service_context").cstate.update(key, resp) + self.superior_get("context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): """ @@ -52,7 +52,7 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): _state = get_state_parameter(request_args, kwargs) parameters = list(self.msg_type.c_param.keys()) - _context = self.client_get("service_context") + _context = self.superior_get("context") _args = _context.cstate.get_set(_state, claim=parameters) if "grant_type" not in _args: diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index cf381869..c52e6dc1 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -154,7 +154,7 @@ def add_support(services, signing_algorithms): # Access token request should use DPoP header _service = services["accesstoken"] - _context = _service.client_get("service_context") + _context = _service.superior_get("context") _context.add_on["dpop"] = { # "key": key_by_alg(signing_algorithm), "sign_algs": signing_algorithms diff --git a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py index 9944ffb9..9815896c 100644 --- a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py +++ b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py @@ -73,7 +73,7 @@ def add_support( # Access token request should use DPoP header _service = services["userinfo"] - _context = _service.client_get("service_context") + _context = _service.superior_get("context") _context.add_on["identity_assurance"] = { "verified_claims_supported": True, "trust_frameworks_supported": trust_frameworks_supported, diff --git a/src/idpyoidc/client/oauth2/add_on/pkce.py b/src/idpyoidc/client/oauth2/add_on/pkce.py index d45f411a..5c250015 100644 --- a/src/idpyoidc/client/oauth2/add_on/pkce.py +++ b/src/idpyoidc/client/oauth2/add_on/pkce.py @@ -22,7 +22,7 @@ def add_code_challenge(request_args, service, **kwargs): :param kwargs: Extra set of keyword arguments :return: Updated set of request arguments """ - _context = service.client_get("service_context") + _context = service.superior_get("context") _kwargs = _context.add_on["pkce"] try: @@ -69,7 +69,7 @@ def add_code_verifier(request_args, service, **kwargs): _state = request_args.get("state") if _state is None: _state = kwargs.get("state") - _item = service.client_get("service_context").cstate.get_set(_state, claim=['code_verifier']) + _item = service.superior_get("context").cstate.get_set(_state, claim=['code_verifier']) request_args.update(_item) return request_args @@ -91,7 +91,7 @@ def add_support(service, code_challenge_length, code_challenge_method): """ if "authorization" in service and "accesstoken" in service: _service = service["authorization"] - _context = _service.client_get("service_context") + _context = _service.superior_get("context") _context.add_on["pkce"] = { "code_challenge_length": code_challenge_length, "code_challenge_method": code_challenge_method, diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py index b33072ee..d40c7d52 100644 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py @@ -16,7 +16,7 @@ def push_authorization(request_args, service, **kwargs): :param kwargs: Extra keyword arguments. """ - _context = service.client_get("service_context") + _context = service.superior_get("context") method_args = _context.add_on["pushed_authorization"] # construct the message body @@ -66,7 +66,7 @@ def add_support( http_client = requests _service = services["authorization"] - _service.client_get("service_context").add_on["pushed_authorization"] = { + _service.superior_get("context").add_on["pushed_authorization"] = { "body_format": body_format, "signing_algorithm": signing_algorithm, "http_client": http_client, diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index b557339e..e28b1fff 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -45,20 +45,20 @@ class Authorization(Service): } } - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter]) self.post_construct.append(self.store_auth_request) def update_service_context(self, resp, key="", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").cstate.update(key, resp) + self.superior_get("context").cstate.update(key, resp) def store_auth_request(self, request_args=None, **kwargs): """Store the authorization request in the state DB.""" _key = get_state_parameter(request_args, kwargs) - self.client_get("service_context").cstate.update(_key, request_args) + self.superior_get("context").cstate.update(_key, request_args) return request_args def gather_request_args(self, **kwargs): @@ -66,8 +66,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: - # _cb = self.client_get("service_context").get_usage("callback_uris") - ar_args["redirect_uri"] = self.client_get("service_context").get_usage( + ar_args["redirect_uri"] = self.superior_get("context").get_usage( "redirect_uris")[0] except (KeyError, AttributeError): raise MissingParameter("redirect_uri") @@ -91,7 +90,7 @@ def post_parse_response(self, response, **kwargs): pass else: if _key: - item = self.client_get("service_context").cstate.get_set( + item = self.superior_get("context").cstate.get_set( _key, message=oauth2.AuthorizationRequest) try: response["scope"] = item["scope"] diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py index 1837e180..9f69f4b6 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py @@ -18,10 +18,10 @@ class CCAccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) - def update_service_context(self, resp, key: Optional[str] = '', **kwargs): + def update_service_context(self, resp, key: Optional[str] = "cc", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").cstate.update(key, resp) + self.superior_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py index 838fdb9f..111cc684 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py @@ -16,15 +16,15 @@ class CCRefreshAccessToken(Service): default_authn_method = "bearer_header" http_method = "POST" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct.append(self.cc_pre_construct) self.post_construct.append(self.cc_post_construct) def cc_pre_construct(self, request_args=None, **kwargs): _state_id = kwargs.get("state", "cc") parameters = ["refresh_token"] - _current = self.client_get("service_context").cstate + _current = self.superior_get("context").cstate _args = _current.get_set(_state_id, claim=parameters) if request_args is None: @@ -44,7 +44,7 @@ def cc_post_construct(self, request_args, **kwargs): return request_args - def update_service_context(self, resp, key: Optional[str] = "", **kwargs): + def update_service_context(self, resp, key="cc", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").cstate.update(key, resp) + self.superior_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/refresh_access_token.py b/src/idpyoidc/client/oauth2/refresh_access_token.py index 6ba8f986..f3345bfc 100644 --- a/src/idpyoidc/client/oauth2/refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/refresh_access_token.py @@ -23,21 +23,21 @@ class RefreshAccessToken(Service): default_authn_method = "bearer_header" http_method = "POST" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").cstate.update(key, resp) + self.superior_get("context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, **kwargs): """Preconstructor of request arguments""" _state = get_state_parameter(request_args, kwargs) parameters = list(self.msg_type.c_param.keys()) - _current = self.client_get("service_context").cstate + _current = self.superior_get("context").cstate _args = _current.get_set(_state, claim=parameters) if request_args is None: diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index 8f2b8929..49da6959 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -24,8 +24,8 @@ class ServerMetadata(Service): _supports = {} - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) def get_endpoint(self): """ @@ -34,7 +34,7 @@ def get_endpoint(self): :return: Service endpoint """ try: - _iss = self.client_get("service_context").issuer + _iss = self.superior_get("context").issuer except AttributeError: _iss = self.endpoint @@ -69,7 +69,7 @@ def _verify_issuer(self, resp, issuer): # In some cases we can live with the two URLs not being # the same. But this is an excepted that has to be explicit try: - self.client_get("service_context").allow["issuer_mismatch"] + self.superior_get("context").allow["issuer_mismatch"] except KeyError: if _issuer != _pcr_issuer: raise OidcServiceError( @@ -86,7 +86,7 @@ def _set_endpoints(self, resp): # a name ending in '_endpoint' so I can look specifically # for those if key.endswith("_endpoint"): - _srv = self.client_get("service_by_endpoint_name", key) + _srv = self.superior_get("service_by_endpoint_name", key) if _srv: _srv.endpoint = val @@ -99,7 +99,7 @@ def _update_service_context(self, resp): :param service_context: Information collected/used by services """ - _context = self.client_get("service_context") + _context = self.superior_get("context") # Verify that the issuer value received is the same as the # url that was used as service endpoint (without the .well-known part) if "issuer" in resp: diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 3214d261..a87d70d0 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -80,8 +80,8 @@ def pre_construct_pick_redirect_uri( request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, **kwargs ): - request_args["redirect_uri"] = pick_redirect_uri(service.client_get("service_context"), - entity=service.client_get("entity"), + request_args["redirect_uri"] = pick_redirect_uri(service.superior_get("context"), + entity=service.superior_get("entity"), request_args=request_args) return request_args, {} diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 042a71de..1e264081 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -40,6 +40,26 @@ # This should probably be part of the configuration MAX_AUTHENTICATION_AGE = 86400 +PREFERENCE2PROVIDER = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", +} + +PROVIDER2PREFERENCE = dict([(v, k) for k, v in PREFERENCE2PROVIDER.items()]) PROVIDER_DEFAULT = { "token_endpoint_auth_method": "client_secret_basic", @@ -60,12 +80,10 @@ def __init__( httplib=None, services=None, httpc_params=None, + **kwargs ): - if isinstance(config, Configuration): - _srvs = services or config.conf.get("services", DEFAULT_OIDC_SERVICES) - else: - _srvs = services or config.get("services", DEFAULT_OIDC_SERVICES) + _srvs = services or DEFAULT_OIDC_SERVICES oauth2.Client.__init__( self, @@ -75,7 +93,8 @@ def __init__( httplib=httplib, services=_srvs, httpc_params=httpc_params, - client_type="oidc" + client_type="oidc", + **kwargs ) _context = self.get_service_context() @@ -100,7 +119,7 @@ def fetch_distributed_claims(self, userinfo, callback=None): if "access_token" in spec: cauth = BearerHeader() httpc_params = cauth.construct( - service=self.client_get("service", "userinfo"), + service=self.superior_get("service", "userinfo"), access_token=spec["access_token"], ) _resp = self.http.send(spec["endpoint"], "GET", **httpc_params) @@ -109,7 +128,7 @@ def fetch_distributed_claims(self, userinfo, callback=None): token = callback(spec["endpoint"]) cauth = BearerHeader() httpc_params = cauth.construct( - service=self.client_get("service", "userinfo"), access_token=token + service=self.superior_get("service", "userinfo"), access_token=token ) _resp = self.http.send(spec["endpoint"], "GET", **httpc_params) else: diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 3c392615..b6980e90 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -29,6 +29,8 @@ class AccessToken(access_token.AccessToken): def __init__(self, client_get, conf: Optional[dict] = None): access_token.AccessToken.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf: Optional[dict] = None): + access_token.AccessToken.__init__(self, superior_get, conf=conf) def gather_verify_arguments( self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None @@ -38,8 +40,8 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.client_get("service_context") - _entity = self.client_get("entity") + _context = self.superior_get("context") + _entity = self.superior_get("entity") kwargs = { "client_id": _entity.get_client_id(), @@ -70,7 +72,7 @@ def gather_verify_arguments( return kwargs def update_service_context(self, resp, key: Optional[str] ="", **kwargs): - _cstate = self.client_get("service_context").cstate + _cstate = self.superior_get("context").cstate try: _idt = resp[verified_claim_name("id_token")] except KeyError: @@ -90,5 +92,7 @@ def update_service_context(self, resp, key: Optional[str] ="", **kwargs): _cstate.update(key, resp) def get_authn_method(self): - _context = self.client_get("service_context") - return _context.get_usage("token_endpoint_auth_method", self.default_authn_method) + try: + return self.superior_get("service_context").behaviour["token_endpoint_auth_method"] + except KeyError: + return self.default_authn_method diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index e1484fcc..80a4a54d 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -55,8 +55,9 @@ class Authorization(authorization.Authorization): } } - def __init__(self, client_get, conf=None): - authorization.Authorization.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + authorization.Authorization.__init__(self, superior_get, conf=conf) + self.default_request_args = {"scope": ["openid"]} self.pre_construct = [ self.set_state, pre_construct_pick_redirect_uri, @@ -67,7 +68,7 @@ def __init__(self, client_get, conf=None): self.default_request_args['scope'] = ['openid'] def set_state(self, request_args, **kwargs): - _context = self.client_get("service_context") + _context = self.superior_get("context") try: _state = kwargs["state"] except KeyError: @@ -81,14 +82,14 @@ def set_state(self, request_args, **kwargs): return request_args, {} def update_service_context(self, resp, key="", **kwargs): - _context = self.client_get("service_context") + _context = self.superior_get("context") if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) _context.cstate.update(key, resp) def get_request_from_response(self, response): - _context = self.client_get("service_context") + _context = self.superior_get("service_context") return _context.cstate.get_set(response["state"], message=oauth2.AuthorizationRequest) def post_parse_response(self, response, **kwargs): @@ -98,7 +99,7 @@ def post_parse_response(self, response, **kwargs): if _idt: # If there is a verified ID Token then we have to do nonce # verification. - _req_nonce = self.client_get("service_context").cstate.get_set( + _req_nonce = self.superior_get("context").cstate.get_set( response["state"], claim=['nonce']).get('nonce') if _req_nonce: _id_token_nonce = _idt.get("nonce") @@ -109,8 +110,7 @@ def post_parse_response(self, response, **kwargs): return response def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): - _context = self.client_get("service_context") - + _context = self.superior_get("context") if request_args is None: request_args = {} @@ -179,7 +179,7 @@ def get_request_object_signing_alg(self, **kwargs): break if not alg: - _context = self.client_get("service_context") + _context = self.superior_get("context") try: alg = _context.work_environment.get_usage("request_object_signing_alg") except KeyError: # Use default @@ -193,8 +193,7 @@ def store_request_on_file(self, req, **kwargs): :param kwargs: Extra keyword arguments :return: The URL the OP should use to access the file """ - _context = self.client_get("service_context") - + _context = self.superior_get("context") _webname = _context.get_usage("request_uris") if _webname is None: filename, _webname = construct_request_uri(**kwargs) @@ -215,7 +214,7 @@ def construct_request_parameter( alg = self.get_request_object_signing_alg(**kwargs) kwargs["request_object_signing_alg"] = alg - _context = self.client_get("service_context") + _context = self.superior_get("context") if "keys" not in kwargs and alg and alg != "none": kwargs["keys"] = _context.keyjar @@ -267,7 +266,7 @@ def oidc_post_construct(self, req, **kwargs): :param kwargs: Extra keyword arguments :return: A possibly modified request. """ - _context = self.client_get("service_context") + _context = self.superior_get("context") if "openid" in req["scope"]: _response_type = req["response_type"][0] if "id_token" in _response_type or "code" in _response_type: @@ -317,7 +316,7 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.client_get("service_context") + _context = self.superior_get("context") kwargs = { "iss": _context.issuer, "keyjar": _context.keyjar, diff --git a/src/idpyoidc/client/oidc/backchannel_authentication.py b/src/idpyoidc/client/oidc/backchannel_authentication.py index 86e09d50..0811e322 100644 --- a/src/idpyoidc/client/oidc/backchannel_authentication.py +++ b/src/idpyoidc/client/oidc/backchannel_authentication.py @@ -17,8 +17,8 @@ class BackChannelAuthentication(Service): service_name = "backchannel_authentication" response_body_type = "json" - def __init__(self, client_get, conf=None, **kwargs): - Service.__init__(self, client_get=client_get, conf=conf, **kwargs) + def __init__(self, superior_get, conf=None, **kwargs): + Service.__init__(self, superior_get=superior_get, conf=conf, **kwargs) self.default_request_args = {"scope": ["openid"]} self.pre_construct = [] self.post_construct = [] @@ -37,8 +37,8 @@ class ClientNotification(Service): response_body_type = "" http_method = "POST" - def __init__(self, client_get, conf=None, **kwargs): - Service.__init__(self, client_get=client_get, conf=conf, **kwargs) + def __init__(self, superior_get, conf=None, **kwargs): + Service.__init__(self, superior_get=superior_get, conf=conf, **kwargs) self.pre_construct = [] self.post_construct = [] diff --git a/src/idpyoidc/client/oidc/check_id.py b/src/idpyoidc/client/oidc/check_id.py index 6c3973cd..712972f5 100644 --- a/src/idpyoidc/client/oidc/check_id.py +++ b/src/idpyoidc/client/oidc/check_id.py @@ -19,12 +19,12 @@ class CheckID(Service): synchronous = True service_name = "check_id" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args: Optional[dict]=None, **kwargs): - _args = self.client_get("service_context").cstate.get_set( + _args = self.superior_get("context").cstate.get_set( kwargs["state"], claim=["id_token"] ) diff --git a/src/idpyoidc/client/oidc/check_session.py b/src/idpyoidc/client/oidc/check_session.py index 525744fb..422142fc 100644 --- a/src/idpyoidc/client/oidc/check_session.py +++ b/src/idpyoidc/client/oidc/check_session.py @@ -18,12 +18,12 @@ class CheckSession(Service): synchronous = True service_name = "check_session" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args=None, **kwargs): - _args = self.client_get("service_context").cstate.get_set(kwargs["state"], + _args = self.superior_get("context").cstate.get_set(kwargs["state"], claim=["id_token"]) if request_args: request_args.update(_args) diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 50723c24..2343d235 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -25,7 +25,7 @@ def add_redirect_uris(request_args, service=None, **kwargs): :param kwargs: Possible extra keyword arguments :return: A possibly augmented set of request arguments. """ - _work_environment = service.client_get("service_context").work_environment + _work_environment = service.superior_get("context").work_environment if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. @@ -49,12 +49,12 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): _supports = {} - def __init__(self, client_get, conf=None): - server_metadata.ServerMetadata.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + server_metadata.ServerMetadata.__init__(self, superior_get, conf=conf) - def update_service_context(self, resp, key, **kwargs): - _context = self.client_get("service_context") - self._update_service_context(resp) # set endpoints and import keys + def update_service_context(self, resp, **kwargs): + _context = self.superior_get("context") + self._update_service_context(resp) _context.map_supported_to_preferred(resp) if "pre_load_keys" in self.conf and self.conf["pre_load_keys"]: _jwks = _context.keyjar.export_jwks_as_json(issuer=resp["issuer"]) @@ -73,7 +73,7 @@ def match_preferences(self, pcr=None, issuer=None): :param pcr: Provider configuration response if available :param issuer: The issuer identifier """ - _context = self.client_get("service_context") + _context = self.superior_get("context") if not pcr: pcr = _context.provider_info diff --git a/src/idpyoidc/client/oidc/read_registration.py b/src/idpyoidc/client/oidc/read_registration.py index 252b9520..a105fed5 100644 --- a/src/idpyoidc/client/oidc/read_registration.py +++ b/src/idpyoidc/client/oidc/read_registration.py @@ -19,7 +19,7 @@ class RegistrationRead(Service): def get_endpoint(self): try: - return self.client_get("service_context").registration_response[ + return self.superior_get("context").registration_response[ "registration_client_uri" ] except KeyError: @@ -40,7 +40,7 @@ def get_authn_header(self, request, authn_method, **kwargs): if authn_method == "client_secret_basic": LOGGER.debug("Client authn method: %s", authn_method) headers["Authorization"] = "Bearer {}".format( - self.client_get("service_context").registration_response[ + self.superior_get("context").registration_response[ "registration_access_token" ] ) diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index 85274d93..0b209ff2 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -8,7 +8,7 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): error_msg = oidc.ResponseMessage def get_authn_method(self): - _work_environment = self.client_get("service_context").work_environment + _work_environment = self.superior_get("context").work_environment try: return _work_environment.get_usage("token_endpoint_auth_method") except KeyError: diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 1ddeacd4..bcce75d5 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -29,7 +29,7 @@ def __init__(self, client_get, conf=None): self.post_construct = [self.oidc_post_construct] def add_client_preference(self, request_args=None, **kwargs): - _context = self.client_get("service_context") + _context = self.superior_get("context") _use = _context.map_preferred_to_registered() for prop, spec in self.msg_type.c_param.items(): if prop in request_args: @@ -64,7 +64,7 @@ def update_service_context(self, resp, key="", **kwargs): # if "token_endpoint_auth_method" not in resp: # resp["token_endpoint_auth_method"] = "client_secret_basic" - _context = self.client_get("service_context") + _context = self.superior_get("context") _context.map_preferred_to_registered(resp) _keyjar = _context.keyjar diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 5c54b806..a6dbe231 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -49,8 +49,8 @@ class UserInfo(Service): "encrypt_userinfo_supported": None } - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, superior_get, conf=None): + Service.__init__(self, superior_get, conf=conf) self.pre_construct = [self.oidc_pre_construct, carry_state] def oidc_pre_construct(self, request_args=None, **kwargs): @@ -60,7 +60,7 @@ def oidc_pre_construct(self, request_args=None, **kwargs): if "access_token" in request_args: pass else: - request_args = self.client_get("service_context").cstate.get_set( + request_args = self.superior_get("context").cstate.get_set( kwargs["state"], claim=["access_token"] ) @@ -68,7 +68,7 @@ def oidc_pre_construct(self, request_args=None, **kwargs): return request_args, {} def post_parse_response(self, response, **kwargs): - _context = self.client_get("service_context") + _context = self.superior_get("context") _current = _context.cstate _args = _current.get_set(kwargs["state"], claim=[verified_claim_name("id_token")]) @@ -118,7 +118,7 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.client_get("service_context") + _context = self.superior_get("context") kwargs = { "client_id": _context.get_client_id(), "iss": _context.issuer, diff --git a/src/idpyoidc/client/oidc/webfinger.py b/src/idpyoidc/client/oidc/webfinger.py index ddfba9ee..fe4782b2 100644 --- a/src/idpyoidc/client/oidc/webfinger.py +++ b/src/idpyoidc/client/oidc/webfinger.py @@ -35,8 +35,8 @@ class WebFinger(Service): http_method = "GET" response_body_type = "json" - def __init__(self, client_get, conf=None, rel="", **kwargs): - Service.__init__(self, client_get, conf=conf, **kwargs) + def __init__(self, superior_get, conf=None, rel="", **kwargs): + Service.__init__(self, superior_get, conf=conf, **kwargs) self.rel = rel or OIC_ISSUER @@ -55,7 +55,7 @@ def update_service_context(self, resp, key="", **kwargs): if _href.startswith("http://") and not _http_allowed: raise ValueError("http link not allowed ({})".format(_href)) - self.client_get("service_context").issuer = link["href"] + self.superior_get("context").issuer = link["href"] break return resp @@ -150,7 +150,7 @@ def get_request_parameters(self, request_args=None, **kwargs): _resource = kwargs["resource"] except KeyError: try: - _resource = self.client_get("service_context").config["resource"] + _resource = self.superior_get("context").config["resource"] except KeyError: raise MissingRequiredAttribute("resource") diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 0398261b..57d22906 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -126,7 +126,7 @@ def state2issuer(self, state): :return: An Issuer ID """ for _rp in self.issuer2rp.values(): - _iss = _rp.client_get("service_context").cstate.get_set( + _iss = _rp.superior_get("context").cstate.get_set( state, claim=['iss']).get('iss') if _iss: return _iss @@ -154,7 +154,7 @@ def get_session_information(self, key, client=None): if not client: client = self.get_client_from_session_key(key) - return client.client_get("service_context").cstate.get(key) + return client.superior_get("context").cstate.get(key) def init_client(self, issuer): """ @@ -197,7 +197,7 @@ def init_client(self, issuer): logger.error(message) raise - _context = client.client_get("service_context") + _context = client.superior_get("context") if _context.iss_hash: self.hash2issuer[_context.iss_hash] = issuer # If non persistent @@ -232,7 +232,7 @@ def do_provider_info( else: raise ValueError("Missing state/session key") - _context = client.client_get("service_context") + _context = client.superior_get("context") if not _context.get("provider_info"): dynamic_provider_info_discovery(client, behaviour_args=behaviour_args) return _context.get("provider_info")["issuer"] @@ -243,7 +243,7 @@ def do_provider_info( # a name ending in '_endpoint' so I can look specifically # for those if key.endswith("_endpoint"): - for _srv in client.client_get("services").values(): + for _srv in client.superior_get("services").values(): # Every service has an endpoint_name assigned # when initiated. This name *MUST* match the # endpoint names used in the provider info @@ -299,7 +299,7 @@ def do_client_registration( else: raise ValueError("Missing state/session key") - _context = client.client_get("service_context") + _context = client.superior_get("context") _iss = _context.get("issuer") self.hash2issuer[iss_id] = _iss @@ -421,8 +421,8 @@ def init_authorization( else: raise ValueError("Missing state/session key") - _context = client.client_get("service_context") - _entity = client.client_get("entity") + _context = client.superior_get("context") + _entity = client.superior_get("entity") _nonce = rndstr(24) _response_type = self._get_response_type(_context, req_args) request_args = { @@ -518,7 +518,7 @@ def get_client_authn_method(client, endpoint): :return: The client authentication method """ if endpoint == "token_endpoint": - am = client.client_get("service_context").get_usage("token_endpoint_auth_method") + am = client.superior_get("context").get_usage("token_endpoint_auth_method") if not am: return "" else: @@ -542,7 +542,7 @@ def get_tokens(self, state, client: Optional[Client] = None): if client is None: client = self.get_client_from_session_key(state) - _context = client.client_get("service_context") + _context = client.superior_get("context") _claims = _context.cstate.get_set(state, claim=['code', 'redirect_uri']) req_args = { @@ -628,8 +628,7 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): client = self.get_client_from_session_key(state) if not access_token: - _arg = client.client_get("service_context").cstate.get_set(state, - claim=["access_token"]) + _arg = client.superior_get("context").cstate.get_set(state, claim=["access_token"]) access_token = _arg["access_token"] request_args = {"access_token": access_token} @@ -685,7 +684,7 @@ def finalize_auth( if is_error_message(authorization_response): return authorization_response - _context = client.client_get("service_context") + _context = client.superior_get("context") try: _iss = _context.cstate.get_set( authorization_response["state"], claim=['iss']).get('iss') @@ -726,7 +725,7 @@ def get_access_and_id_token( if client is None: client = self.get_client_from_session_key(state) - _context = client.client_get("service_context") + _context = client.superior_get("context") resp_attr = authorization_response or _context.cstate.get_set(state, message=AuthorizationResponse) @@ -812,7 +811,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): _id_token = token.get("id_token") logger.debug(f"ID Token: {_id_token}") - if client.client_get("service", "userinfo") and token["access_token"]: + if client.superior_get("service", "userinfo") and token["access_token"]: inforesp = self.get_user_info( state=authorization_response["state"], client=client, @@ -829,7 +828,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): logger.debug("UserInfo: %s", inforesp) - _context = client.client_get("service_context") + _context = client.superior_get("context") try: _sid_support = _context.get("provider_info")["backchannel_logout_session_required"] except KeyError: @@ -872,7 +871,7 @@ def has_active_authentication(self, state): client = self.get_client_from_session_key(state) # Look for an IdToken - _arg = client.client_get("service_context").cstate.get_set(state, + _arg = client.superior_get("context").cstate.get_set(state, claim=["__verified_id_token"]) if _arg: @@ -896,7 +895,7 @@ def get_valid_access_token(self, state): now = utc_time_sans_frac() client = self.get_client_from_session_key(state) - _context = client.client_get("service_context") + _context = client.superior_get("context") _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) if "access_token" in _args: access_token = _args["access_token"] @@ -938,7 +937,7 @@ def logout( client = self.get_client_from_session_key(state) try: - srv = client.client_get("service", "end_session") + srv = client.superior_get("service", "end_session") except KeyError: raise OidcServiceError("Does not know how to logout") @@ -970,7 +969,7 @@ def close( def clear_session(self, state): client = self.get_client_from_session_key(state) - client.client_get("service_context").cstate.remove_state(state) + client.superior_get("context").cstate.remove_state(state) def backchannel_logout(client, request="", request_args=None): @@ -986,7 +985,7 @@ def backchannel_logout(client, request="", request_args=None): else: raise MissingRequiredAttribute("logout_token") - _context = client.client_get("service_context") + _context = client.superior_get("context") kwargs = { "aud": client.get_client_id(), "iss": _context.get("issuer"), @@ -1030,7 +1029,7 @@ def load_registration_response(client, request_args=None): :param client: A :py:class:`idpyoidc.client.oidc.Client` instance """ - if not client.client_get("service_context").get_client_id(): + if not client.superior_get("context").get_client_id(): try: response = client.do_request("registration", request_args=request_args) except KeyError: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 8357038f..ed67491e 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -63,20 +63,17 @@ class Service(ImpExp): "response_cls": object, } - init_args = ["client_get"] + init_args = ["superior_get"] _supports = {} _callback_path = {} def __init__( - self, - client_get: Callable, - conf: Optional[Union[dict, Configuration]] = None, - **kwargs + self, superior_get: Callable, conf: Optional[Union[dict, Configuration]] = None, **kwargs ): ImpExp.__init__(self) - self.client_get = client_get + self.superior_get = superior_get self.default_request_args = {} if conf: @@ -118,7 +115,7 @@ def gather_request_args(self, **kwargs): """ ar_args = kwargs.copy() - _context = self.client_get("service_context") + _context = self.superior_get("context") _use = _context.collect_usage() if not _use: _use = _context.map_preferred_to_registered() @@ -269,7 +266,7 @@ def init_authentication_method(self, request, authn_method, http_args=None, **kw if authn_method: LOGGER.debug("Client authn method: %s", authn_method) - _context = self.client_get("service_context") + _context = self.superior_get("context") try: _func = _context.client_authn_method[authn_method] except KeyError: # not one of the common @@ -305,7 +302,7 @@ def get_endpoint(self): if self.endpoint: return self.endpoint - return self.client_get("service_context").provider_info[self.endpoint_name] + return self.superior_get("context").provider_info[self.endpoint_name] def get_authn_header( self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs @@ -362,7 +359,7 @@ def get_headers( for meth in self.construct_extra_headers: _headers = meth( - self.client_get("service_context"), + self.superior_get("context"), headers=_headers, request=request, authn_method=authn_method, @@ -409,7 +406,7 @@ def get_request_parameters( _info = {"method": method, "request": request} _args = kwargs.copy() - _context = self.client_get("service_context") + _context = self.superior_get("context") if _context.issuer: _args["iss"] = _context.issuer @@ -485,7 +482,7 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.client_get("service_context") + _context = self.superior_get("context") kwargs = { "iss": _context.issuer, "keyjar": _context.keyjar, @@ -500,7 +497,7 @@ def gather_verify_arguments( return kwargs def _do_jwt(self, info): - _context = self.client_get("service_context") + _context = self.superior_get("context") args = {"allowed_sign_algs": _context.get_sign_alg(self.service_name)} enc_algs = _context.get_enc_alg_enc(self.service_name) args["allowed_enc_algs"] = enc_algs["alg"] @@ -510,7 +507,7 @@ def _do_jwt(self, info): return _jwt.unpack(info) def _do_response(self, info, sformat, **kwargs): - _context = self.client_get("service_context") + _context = self.superior_get("context") try: resp = self.response_cls().deserialize(info, sformat, iss=_context.issuer, **kwargs) @@ -662,12 +659,12 @@ def callback_uris(self): return list(self._callback_path.keys()) -def init_services(service_definitions, client_get): +def init_services(service_definitions, superior_get): """ Initiates a set of services :param service_definitions: A dictionary containing service definitions - :param client_get: A function that returns different things from the base entity. + :param superior_get: A function that returns different things from the base entity. :return: A dictionary, with service name as key and the service instance as value. """ @@ -678,7 +675,7 @@ def init_services(service_definitions, client_get): except KeyError: kwargs = {} - kwargs.update({"client_get": client_get}) + kwargs.update({"superior_get": superior_get}) if isinstance(service_configuration["class"], str): _cls = importer(service_configuration["class"]) diff --git a/src/idpyoidc/context.py b/src/idpyoidc/context.py index e0fa61b5..de558b2b 100644 --- a/src/idpyoidc/context.py +++ b/src/idpyoidc/context.py @@ -1,6 +1,12 @@ import copy +from typing import Optional +from typing import Union from urllib.parse import quote_plus +from cryptojwt import KeyJar +from cryptojwt.key_jar import init_key_jar +from idpyoidc.configure import Configuration + from idpyoidc.impexp import ImpExp diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index df93fd21..14887b00 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -1,12 +1,14 @@ # Server specific defaults and a basic Server class import logging from typing import Any +from typing import Callable from typing import Optional from typing import Union from cryptojwt import KeyJar from idpyoidc.impexp import ImpExp +from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.server import authz from idpyoidc.server.client_authn import client_auth_setup from idpyoidc.server.configure import ASConfiguration @@ -33,7 +35,7 @@ def do_endpoints(conf, server_get): class Server(ImpExp): - parameter = {"endpoint": [Endpoint], "endpoint_context": EndpointContext} + parameter = {"endpoint": [Endpoint], "context": EndpointContext} def __init__( self, @@ -42,13 +44,10 @@ def __init__( cwd: Optional[str] = "", cookie_handler: Optional[Any] = None, httpc: Optional[Any] = None, + parent_get: Optional[Callable] = None ): ImpExp.__init__(self) self.conf = conf - - self.endpoint = do_endpoints(conf, self.server_get) - - # endpoint context MUST be done after do_endpoints !! self.endpoint_context = EndpointContext( conf=conf, server_get=self.server_get, @@ -57,13 +56,15 @@ def __init__( cookie_handler=cookie_handler, httpc=httpc, ) - self.endpoint_context.set_provider_info() + self.parent_get = parent_get self.endpoint_context.authz = self.setup_authz() - # _cap = get_provider_capabilities(conf, self.endpoint) - # self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.setup_authentication(self.endpoint_context) + self.endpoint = do_endpoints(conf, self.server_get) + _cap = get_provider_capabilities(conf, self.endpoint) + + self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.endpoint_context.do_add_on(endpoints=self.endpoint) self.endpoint_context.session_manager = create_session_manager( @@ -113,6 +114,9 @@ def get_endpoint(self, endpoint_name, *arg): def get_endpoint_context(self, *arg): return self.endpoint_context + def get_server(self, *args): + return self + def setup_authz(self): authz_spec = self.conf.get("authz") if authz_spec: diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index b2de74b3..2ae4bf9c 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -29,7 +29,7 @@ def usage_rules(self, client_id: Optional[str] = ""): return _usage_rules try: - _per_client = self.server_get("endpoint_context").cdb[client_id]["token_usage_rules"] + _per_client = self.server_get("context").cdb[client_id]["token_usage_rules"] except KeyError: pass else: @@ -61,7 +61,7 @@ def __call__( request: Union[dict, Message], resources: Optional[list] = None, ) -> Grant: - session_info = self.server_get("endpoint_context").session_manager.get_session_info( + session_info = self.server_get("context").session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -86,7 +86,7 @@ def __call__( if not scopes: scopes = request.get("scope", []) grant.scope = scopes - grant.claims = self.server_get("endpoint_context").claims_interface.get_claims_all_usage( + grant.claims = self.server_get("context").claims_interface.get_claims_all_usage( session_id=session_id, scopes=scopes ) @@ -101,7 +101,7 @@ def __call__( resources: Optional[list] = None, ) -> Grant: args = self.grant_config.copy() - grant = self.server_get("endpoint_context").session_manager.get_grant(session_id=session_id) + grant = self.server_get("context").session_manager.get_grant(session_id=session_id) for arg, val in args: setattr(grant, arg, val) return grant diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index c36168a4..26d23731 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -74,7 +74,7 @@ def verify( :return: """ res = self._verify( - self.server_get("endpoint_context"), + self.server_get("context"), request=request, authorization_token=authorization_token, endpoint=endpoint, diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 61538603..0feff177 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -2,6 +2,7 @@ import copy import logging import os +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -171,6 +172,7 @@ def __init__( port: Optional[int] = 0, file_attributes: Optional[List[str]] = None, dir_attributes: Optional[List[str]] = None, + superior_get: Optional[Callable] = None ): conf = copy.deepcopy(conf) diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index b84db51d..f41bc91e 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -160,7 +160,7 @@ def parse_request( LOGGER.debug("- {} -".format(self.endpoint_name)) LOGGER.info("Request: %s" % sanitize(request)) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if http_info is None: http_info = {} @@ -174,7 +174,7 @@ def parse_request( req = _cls_inst.deserialize( request, "jwt", - keyjar=_context.keyjar, + keyjar=self.server_get("keyjar"), verify=_context.httpc_params["verify"], **kwargs ) @@ -196,7 +196,7 @@ def parse_request( else: _client_id = req.get("client_id") - keyjar = _context.keyjar + keyjar = self.server_get("keyjar") # verify that the request message is correct try: @@ -237,7 +237,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No kwargs["get_client_id_from_token"] = getattr(self, "get_client_id_from_token", None) authn_info = verify_client( - endpoint_context=self.server_get("endpoint_context"), + endpoint_context=self.server_get("context"), request=request, http_info=http_info, **kwargs @@ -254,7 +254,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No def do_post_parse_request( self, request: Message, client_id: Optional[str] = "", **kwargs ) -> Message: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") for meth in self.post_parse_request: if isinstance(request, self.error_cls): break @@ -264,7 +264,7 @@ def do_post_parse_request( def do_pre_construct( self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") for meth in self.pre_construct: response_args = meth(response_args, request, endpoint_context=_context, **kwargs) @@ -276,7 +276,7 @@ def do_post_construct( request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") for meth in self.post_construct: response_args = meth(response_args, request, endpoint_context=_context, **kwargs) @@ -435,7 +435,7 @@ def do_response( def allowed_target_uris(self): res = [] - _context = self.server_get("endpoint_context") + _context = self.server_get("context") for t in self.allowed_targets: if t == "": res.append(_context.issuer) diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 4ccf2f0c..2c450f32 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -123,8 +123,10 @@ def __init__( cookie_handler: Optional[Any] = None, httpc: Optional[Any] = None, server_type: Optional[str] = '' + entity_id: Optional[str] = "" ): - OidcContext.__init__(self, conf, entity_id=conf.get("issuer", "")) + _id = entity_id or conf.get("issuer", "") + OidcContext.__init__(self, conf, entity_id=_id) self.conf = conf self.server_get = server_get @@ -245,7 +247,9 @@ def __init__( self.set_scopes_handler() self.dev_auth_db = None - self.claims_interface = init_service(conf["claims_interface"], self.server_get) + _interface = conf.get("claims_interface") + if _interface: + self.claims_interface = init_service(_interface, self.server_get) if isinstance(conf, OPConfiguration): self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 5e1aef16..84ef7d84 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -142,11 +142,11 @@ def add_support(endpoint, **kwargs): if not _algs_supported: _algs_supported = ["RS256"] - _token_endp.server_get("endpoint_context").provider_info[ + _token_endp.server_get("context").provider_info[ "dpop_signing_alg_values_supported" ] = _algs_supported - _endpoint_context = _token_endp.server_get("endpoint_context") + _endpoint_context = _token_endp.server_get("context") _endpoint_context.dpop_enabled = True @@ -163,7 +163,7 @@ def is_usable(self, request=None, authorization_info=None, http_headers=None): def verify(self, authorization_info, **kwargs): client_info = basic_authn(authorization_info) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: diff --git a/src/idpyoidc/server/oauth2/add_on/extra_args.py b/src/idpyoidc/server/oauth2/add_on/extra_args.py index 11132df5..dba819b4 100644 --- a/src/idpyoidc/server/oauth2/add_on/extra_args.py +++ b/src/idpyoidc/server/oauth2/add_on/extra_args.py @@ -47,5 +47,5 @@ def add_support(endpoint, **kwargs): _endp.pre_construct.append(pre_construct) if _added is False: - _endp.server_get("endpoint_context").add_on["extra_args"] = kwargs + _endp.server_get("context").add_on["extra_args"] = kwargs _added = True diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index d7c6ddb6..3ae7991d 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -389,7 +389,7 @@ def mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): usage_rules = grant.usage_rules.get(token_class, {}) token = grant.mint_token( session_id=session_id, - endpoint_context=self.server_get("endpoint_context"), + endpoint_context=self.server_get("context"), token_class=token_class, based_on=based_on, usage_rules=usage_rules, @@ -402,7 +402,7 @@ def mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): if _exp_in: token.expires_at = utc_time_sans_frac() + _exp_in - _mngr = self.server_get("endpoint_context").session_manager + _mngr = self.server_get("context").session_manager _mngr.set(_mngr.unpack_session_key(session_id), grant) return token @@ -555,7 +555,7 @@ def _enforce_resource_indicators_policy(self, request, config): return self.error_cls(error="server_error", error_description="Internal server error") def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") auth_id = kwargs.get("auth_method_id") if auth_id: return _context.authn_broker[auth_id] @@ -579,7 +579,7 @@ def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): } def create_session(self, request, user_id, acr, time_stamp, authn_method): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager authn_event = create_authn_event( user_id, @@ -657,7 +657,7 @@ def setup_auth( authn_class_ref = res["acr"] client_id = request.get("client_id") - _context = self.server_get("endpoint_context") + _context = self.server_get("context") try: _auth_info = kwargs.get("authn", "") if "upm_answer" in request and request["upm_answer"] == "true": @@ -837,7 +837,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if "response_type" in request and request["response_type"] == ["none"]: fragment_enc = False else: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager _sinfo = _mngr.get_session_info(sid, grant=True) @@ -944,7 +944,7 @@ def post_authentication(self, request: Union[dict, Message], session_id: str, ** """ response_info = {} - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager # Do the authorization @@ -1013,7 +1013,7 @@ def authz_part2(self, request, session_id, **kwargs): except Exception as err: return self.error_by_response_mode({}, request, "server_error", err) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") logger.debug(f"resp_info: {resp_info}") @@ -1089,7 +1089,7 @@ def process_request( return request _cid = request["client_id"] - _context = self.server_get("endpoint_context") + _context = self.server_get("context") cinfo = _context.cdb[_cid] # logger.debug("client {}: {}".format(_cid, cinfo)) diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index 11b29cca..75b043d9 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -52,7 +52,7 @@ def _introspect(self, token, client_id, grant): if not aud: aud = grant.resources - _context = self.server_get("endpoint_context") + _context = self.server_get("context") ret = { "active": True, "scope": " ".join(scope), @@ -98,7 +98,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs request_token = _introspect_request["token"] _resp = self.response_cls(active=False) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") try: _session_info = _context.session_manager.get_session_info_by_token( diff --git a/src/idpyoidc/server/oauth2/pushed_authorization.py b/src/idpyoidc/server/oauth2/pushed_authorization.py index c8aa10d5..5a6dd7fe 100644 --- a/src/idpyoidc/server/oauth2/pushed_authorization.py +++ b/src/idpyoidc/server/oauth2/pushed_authorization.py @@ -29,7 +29,7 @@ def process_request(self, request=None, **kwargs): # create URN _urn = "urn:uuid:{}".format(uuid.uuid4()) - self.server_get("endpoint_context").par_db[_urn] = request + self.server_get("context").par_db[_urn] = request return { "http_response": {"request_uri": _urn, "expires_in": self.ttl}, diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index e0a77196..652ec463 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -86,7 +86,7 @@ def _post_parse_request( ): grant_type = request["grant_type"] _helper = self.helper.get(grant_type) - client = kwargs["endpoint_context"].cdb[client_id] + client = kwargs["context"].cdb[client_id] grant_types_supported = client.get("grant_types_supported", self.grant_types_supported) if grant_type not in grant_types_supported: return self.error_cls( @@ -132,7 +132,7 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar return response_args _access_token = response_args["access_token"] - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if isinstance(_helper, TokenExchangeHelper): _handler_key = _helper.get_handler_key(request, _context) diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index 8b17850f..576a06d7 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -62,7 +62,7 @@ def _mint_token( token_args: Optional[dict] = None, token_type: Optional[str] = "", ) -> SessionToken: - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager usage_rules = grant.usage_rules.get(token_class) if usage_rules: @@ -159,7 +159,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager logger.debug("Access Token") @@ -310,7 +310,7 @@ def post_parse_request( :returns: """ - _mngr = self.endpoint.server_get("endpoint_context").session_manager + _mngr = self.endpoint.server_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["code"], grant=True, handler_key="authorization_code" @@ -339,7 +339,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager logger.debug("Refresh Token") @@ -433,7 +433,7 @@ def post_parse_request( """ request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") try: keyjar = _context.keyjar except AttributeError: @@ -498,7 +498,7 @@ def __init__(self, endpoint, config=None): def post_parse_request(self, request, client_id="", **kwargs): request = TokenExchangeRequest(**request.to_dict()) - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") if "token_exchange" in _context.cdb[request["client_id"]]: config = _context.cdb[request["client_id"]]["token_exchange"] else: @@ -567,7 +567,7 @@ def post_parse_request(self, request, client_id="", **kwargs): return resp def _enforce_policy(self, request, token, config): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") subject_token_types_supported = config.get( "subject_token_types_supported", self.token_types_mapping.keys() ) @@ -638,7 +638,7 @@ def token_exchange_response(self, token, issued_token_type): return TokenExchangeResponse(**response_args) def process_request(self, request, **kwargs): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager try: _handler_key = self.token_types_mapping[request["subject_token_type"]] diff --git a/src/idpyoidc/server/oidc/add_on/custom_scopes.py b/src/idpyoidc/server/oidc/add_on/custom_scopes.py index c5daa350..299f619c 100644 --- a/src/idpyoidc/server/oidc/add_on/custom_scopes.py +++ b/src/idpyoidc/server/oidc/add_on/custom_scopes.py @@ -18,7 +18,7 @@ def add_custom_scopes(endpoint, **kwargs): _scopes2claims = SCOPE2CLAIMS.copy() _scopes2claims.update(kwargs) - _context = _endpoint.server_get("endpoint_context") + _context = _endpoint.server_get("context") _context.scopes_handler.set_scopes_mapping(_scopes2claims) pi = _context.provider_info diff --git a/src/idpyoidc/server/oidc/add_on/pkce.py b/src/idpyoidc/server/oidc/add_on/pkce.py index 298b0ac7..958fd1cd 100644 --- a/src/idpyoidc/server/oidc/add_on/pkce.py +++ b/src/idpyoidc/server/oidc/add_on/pkce.py @@ -147,4 +147,4 @@ def add_pkce_support(endpoint: Dict[str, Endpoint], **kwargs): raise ValueError("Unsupported method: {}".format(method)) kwargs["code_challenge_methods"][method] = CC_METHOD[method] - authn_endpoint.server_get("endpoint_context").args["pkce"] = kwargs + authn_endpoint.server_get("context").args["pkce"] = kwargs diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index 6a7d5eef..bf09e56d 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -102,7 +102,7 @@ def do_request_user(self, request_info, **kwargs): else: _login_hint = request_info.get("login_hint") if _login_hint: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if _context.login_hint_lookup: kwargs["req_user"] = _context.login_hint_lookup(_login_hint) return kwargs diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index aaf44ce1..aeb069a8 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -60,11 +60,11 @@ def do_request_user(self, request): elif request.get("login_hint"): _login_hint = request.get("login_hint") if _login_hint: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if _context.login_hint_lookup: _request_user = _context.login_hint_lookup(_login_hint) elif request.get("login_hint_token"): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _request_user = execute( self.parse_login_hint_token, keyjar=_context.keyjar, @@ -79,7 +79,7 @@ def allowed_target_uris(self): The OP MUST accept its Issuer Identifier, Token Endpoint URL, or Backchannel Authentication Endpoint URL as values that identify it as an intended audience. """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") res = [_context.issuer] res.append(self.full_path) res.append(self.server_get("endpoint", "token").full_path) @@ -101,7 +101,7 @@ def process_request( return _error_msg if request_user: # Got a request for a legitimate user, create a session - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _sid = _context.session_manager.create_session( None, request, request_user, client_id=request["client_id"] ) @@ -139,7 +139,7 @@ def _get_session_info(self, request, session_manager): def post_parse_request( self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager _session_id = _mngr.auth_req_id_map[request["auth_req_id"]] _info = _mngr.get_session_info(_session_id) @@ -180,7 +180,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager logger.debug("OIDC Access Token") diff --git a/src/idpyoidc/server/oidc/discovery.py b/src/idpyoidc/server/oidc/discovery.py index 39ff0ecf..a14643d6 100755 --- a/src/idpyoidc/server/oidc/discovery.py +++ b/src/idpyoidc/server/oidc/discovery.py @@ -37,5 +37,5 @@ def do_response(self, response_args=None, request=None, **kwargs): def process_request(self, request=None, **kwargs): return { "subject": request["resource"], - "hrefs": [self.server_get("endpoint_context").issuer], + "hrefs": [self.server_get("context").issuer], } diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 0c0de15a..abdfd20b 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -33,4 +33,4 @@ def add_endpoints(self, info, client_id, endpoint_context, **kwargs): return info def process_request(self, request=None, **kwargs): - return {"response_args": self.server_get("endpoint_context").provider_info} + return {"response_args": self.server_get("context").provider_info} diff --git a/src/idpyoidc/server/oidc/read_registration.py b/src/idpyoidc/server/oidc/read_registration.py index fe58a736..492532ef 100644 --- a/src/idpyoidc/server/oidc/read_registration.py +++ b/src/idpyoidc/server/oidc/read_registration.py @@ -18,13 +18,13 @@ def get_client_id_from_token(self, endpoint_context, token, request=None): if "client_id" in request: if ( request["client_id"] - == self.server_get("endpoint_context").registration_access_token[token] + == self.server_get("context").registration_access_token[token] ): return request["client_id"] return "" def process_request(self, request=None, **kwargs): - _cli_info = self.server_get("endpoint_context").cdb[request["client_id"]] + _cli_info = self.server_get("context").cdb[request["client_id"]] args = {k: v for k, v in _cli_info.items() if k in RegistrationResponse.c_param} comb_uri(args) return {"response_args": RegistrationResponse(**args)} diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 283b5cc5..395ecbe9 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs): def match_client_request(self, request: dict) -> list: err = [] - _provider_info = self.server_get("endpoint_context").provider_info + _provider_info = self.server_get("context").provider_info for key, val in request.items(): if key not in REGISTER2PREFERRED: continue @@ -158,7 +158,7 @@ def match_client_request(self, request: dict) -> list: def do_client_registration(self, request, client_id, ignore=None): if ignore is None: ignore = [] - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _cinfo = _context.cdb[client_id].copy() logger.debug("_cinfo: %s" % sanitize(_cinfo)) @@ -221,24 +221,6 @@ def do_client_registration(self, request, client_id, ignore=None): error_description="%s pointed to illegal URL" % item, ) - # Do I have the necessary keys - # for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: - # if item in request: - # if request[item] in _context.provider_info[PREFERENCE2SUPPORTED[item]]: - # ktyp = alg2keytype(request[item]) - # # do I have this ktyp and for EC type keys the curve - # if ktyp not in ["none", "oct"]: - # _k = [] - # for iss in ["", _context.issuer]: - # _k.extend( - # _context.keyjar.get_signing_key( - # ktyp, alg=request[item], issuer_id=iss - # ) - # ) - # if not _k: - # logger.warning('Lacking support for "{}"'.format(request[item])) - # del _cinfo[item] - t = {"jwks_uri": "", "jwks": None} for item in ["jwks_uri", "jwks"]: @@ -316,8 +298,8 @@ def _verify_sector_identifier(self, request): """ si_url = request["sector_identifier_uri"] try: - res = self.server_get("endpoint_context").httpc.get( - si_url, **self.server_get("endpoint_context").httpc_params + res = self.server_get("context").httpc.get( + si_url, **self.server_get("context").httpc_params ) logger.debug("sector_identifier_uri => %s", sanitize(res.text)) except Exception as err: @@ -388,7 +370,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): error_description=f"Don't support proposed {faulty_claims}" ) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if new_id: if self.kwargs.get("client_id_generator"): cid_generator = importer(self.kwargs["client_id_generator"]["class"]) @@ -461,7 +443,7 @@ def process_request(self, request=None, new_id=True, set_secret=True, **kwargs): if "error" in reg_resp: return reg_resp else: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _cookie = _context.new_cookie( name=_context.cookie_handler.name["register"], client_id=reg_resp["client_id"], diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 841419a8..5743e12c 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -92,18 +92,18 @@ class Session(Endpoint): def __init__(self, server_get, **kwargs): _csi = kwargs.get("check_session_iframe") if _csi and not _csi.startswith("http"): - kwargs["check_session_iframe"] = add_path(server_get("endpoint_context").issuer, _csi) + kwargs["check_session_iframe"] = add_path(server_get("context").issuer, _csi) Endpoint.__init__(self, server_get, **kwargs) self.iv = as_bytes(rndstr(24)) def _encrypt_sid(self, sid): - encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("endpoint_context").symkey)) + encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("context").symkey)) enc_msg = encrypter.encrypt(as_bytes(sid), iv=self.iv) return as_unicode(b64e(enc_msg)) def _decrypt_sid(self, enc_msg): _msg = b64d(as_bytes(enc_msg)) - encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("endpoint_context").symkey)) + encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("context").symkey)) ctx, tag = split_ctx_and_tag(_msg) return as_unicode(encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag))) @@ -115,7 +115,7 @@ def do_back_channel_logout(self, cinfo, sid): :return: Tuple with logout URI and signed logout token """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") try: back_channel_logout_uri = cinfo["backchannel_logout_uri"] @@ -143,12 +143,12 @@ def do_back_channel_logout(self, cinfo, sid): def clean_sessions(self, usids): # Revoke all sessions - _context = self.server_get("endpoint_context") + _context = self.server_get("context") for sid in usids: _context.session_manager.revoke_client_session(sid) def logout_all_clients(self, sid): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager _session_info = _mngr.get_session_info(sid) @@ -217,14 +217,14 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): else: alg = self.kwargs["signing_alg"] - sign_keys = self.server_get("endpoint_context").keyjar.get_signing_key(alg2keytype(alg)) + sign_keys = self.server_get("context").keyjar.get_signing_key(alg2keytype(alg)) _info = _jwt.verify_compact(keys=sign_keys, sigalg=alg) return _info else: raise ValueError("Not a signed JWT") def logout_from_client(self, sid): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _cdb = _context.cdb _session_information = _context.session_manager.get_session_info(sid, grant=True) _client_id = _session_information["client_id"] @@ -257,7 +257,7 @@ def process_request( :param kwargs: :return: """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager if "post_logout_redirect_uri" in request: @@ -371,7 +371,7 @@ def parse_request(self, request, http_info=None, **kwargs): request["access_token"] = auth_info["token"] if isinstance(request, dict): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") request = self.request_cls(**request) if not request.verify(keyjar=_context.keyjar, sigalg=""): raise InvalidRequest("Request didn't verify") @@ -398,7 +398,7 @@ def do_verified_logout(self, sid, alla=False, **kwargs): bcl = _res.get("blu") if bcl: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") # take care of Back channel logout first for _cid, spec in bcl.items(): _url, sjwt = spec @@ -421,7 +421,7 @@ def do_verified_logout(self, sid, alla=False, **kwargs): return _res["flu"].values() if _res.get("flu") else [] def kill_cookies(self): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _handler = _context.cookie_handler session_mngmnt = _handler.make_cookie_content( value="", name=_handler.name["session_management"], max_age=-1 diff --git a/src/idpyoidc/server/oidc/token_helper.py b/src/idpyoidc/server/oidc/token_helper.py index d319b9e2..0a003186 100755 --- a/src/idpyoidc/server/oidc/token_helper.py +++ b/src/idpyoidc/server/oidc/token_helper.py @@ -43,7 +43,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager logger.debug("OIDC Access Token") @@ -175,7 +175,7 @@ def post_parse_request( :returns: """ - _mngr = self.endpoint.server_get("endpoint_context").session_manager + _mngr = self.endpoint.server_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["code"], grant=True, handler_key="authorization_code" @@ -209,7 +209,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager if req["grant_type"] != "refresh_token": @@ -326,7 +326,7 @@ def post_parse_request( """ request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") try: keyjar = _context.keyjar except AttributeError: diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index c965e3e4..9bdb7ce3 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -63,7 +63,7 @@ def do_response( if "error" in kwargs and kwargs["error"]: return Endpoint.do_response(self, response_args, request, **kwargs) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if not client_id: raise MissingValue("client_id") @@ -112,7 +112,7 @@ def do_response( return {"response": resp, "http_headers": http_headers} def process_request(self, request=None, **kwargs): - _mngr = self.server_get("endpoint_context").session_manager + _mngr = self.server_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["access_token"], grant=True, handler_key="access_token" @@ -147,7 +147,7 @@ def process_request(self, request=None, **kwargs): # pass if allowed: - _cntxt = self.server_get("endpoint_context") + _cntxt = self.server_get("context") _claims_restriction = _cntxt.claims_interface.get_claims( _session_info["branch_id"], scopes=token.scope, claims_release_point="userinfo" ) diff --git a/src/idpyoidc/server/scopes.py b/src/idpyoidc/server/scopes.py index 0c239c71..9aab827a 100644 --- a/src/idpyoidc/server/scopes.py +++ b/src/idpyoidc/server/scopes.py @@ -65,7 +65,7 @@ def get_allowed_scopes(self, client_id=None): """ allowed_scopes = self.allowed_scopes if client_id: - client = self.server_get("endpoint_context").cdb.get(client_id) + client = self.server_get("context").cdb.get(client_id) if client is not None: allowed_scopes = client.get("allowed_scopes", allowed_scopes) return allowed_scopes @@ -79,7 +79,7 @@ def get_scopes_mapping(self, client_id=None): """ scopes_to_claims = self._scopes_to_claims if client_id: - client = self.server_get("endpoint_context").cdb.get(client_id) + client = self.server_get("context").cdb.get(client_id) if client is not None: scopes_to_claims = client.get("scopes_to_claims", scopes_to_claims) return scopes_to_claims diff --git a/src/idpyoidc/server/session/claims.py b/src/idpyoidc/server/session/claims.py index 35b43ee4..c0a9d263 100755 --- a/src/idpyoidc/server/session/claims.py +++ b/src/idpyoidc/server/session/claims.py @@ -65,7 +65,7 @@ def _client_claims( claims_release_point: str, secondary_identifier: Optional[str] = "", ): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") add_claims_by_scope = _context.cdb[client_id].get("add_claims", {}).get("by_scope", {}) if add_claims_by_scope: _claims_by_scope = add_claims_by_scope.get(claims_release_point) @@ -93,7 +93,7 @@ def get_claims_from_request( client_id: str = None, secondary_identifier: str = "", ) -> dict: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") # which endpoint module configuration to get the base claims from module = self._get_module(claims_release_point, _context) @@ -159,7 +159,7 @@ def get_claims( "userinfo"/"id_token"/"introspection"/"access_token" :return: Claims specification as a dictionary. """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") session_info = _context.session_manager.get_session_info(session_id, grant=True) client_id = session_info["client_id"] grant = session_info["grant"] @@ -189,7 +189,7 @@ def get_claims_all_usage_from_request( return _claims def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: - grant = self.server_get("endpoint_context").session_manager.get_grant(session_id) + grant = self.server_get("context").session_manager.get_grant(session_id) if grant.authorization_request: auth_req = grant.authorization_request else: @@ -203,7 +203,7 @@ def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: :param claims_restriction: Specifies the upper limit of which claims can be returned :return: """ - meth = self.server_get("endpoint_context").userinfo + meth = self.server_get("context").userinfo if not meth: raise ImproperlyConfigured("userinfo MUST be defined in the configuration") if claims_restriction: diff --git a/src/idpyoidc/server/token/handler.py b/src/idpyoidc/server/token/handler.py index cd05692d..ea52844a 100755 --- a/src/idpyoidc/server/token/handler.py +++ b/src/idpyoidc/server/token/handler.py @@ -169,7 +169,7 @@ def factory( key_defs = [] read_only = False - cwd = server_get("endpoint_context").cwd + cwd = server_get("context").cwd if kwargs.get("jwks_def"): defs = kwargs["jwks_def"] if not jwks_file: diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index 7c58f677..bc38850e 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -150,7 +150,7 @@ def payload( :return: IDToken instance """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _mngr = _context.session_manager session_information = _mngr.get_session_info(session_id, grant=True) grant = session_information["grant"] @@ -236,7 +236,7 @@ def sign_encrypt( :return: IDToken as a signed and/or encrypted JWT """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") client_info = _context.cdb[client_id] alg_dict = get_sign_and_encrypt_algorithms( @@ -269,7 +269,7 @@ def __call__( usage_rules: Optional[dict] = None, **kwargs, ) -> str: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id) @@ -307,7 +307,7 @@ def info(self, token): :return: tuple of token type and session id """ - _context = self.server_get("endpoint_context") + _context = self.server_get("context") _jwt = factory(token) if not _jwt: diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index e552115b..010cc703 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -37,7 +37,7 @@ def __init__( self.lifetime = lifetime self.kwargs = kwargs - _context = server_get("endpoint_context") + _context = server_get("context") # self.key_jar = keyjar or _context.keyjar self.issuer = issuer or _context.issuer self.cdb = _context.cdb @@ -85,7 +85,7 @@ def __call__( payload = self.load_custom_claims(payload) # payload.update(kwargs) - _context = self.server_get("endpoint_context") + _context = self.server_get("context") if usage_rules and "expires_in" in usage_rules: lifetime = usage_rules.get("expires_in") else: @@ -112,7 +112,7 @@ def __call__( return signer.pack(payload) def get_payload(self, token): - _context = self.server_get("endpoint_context") + _context = self.server_get("context") verifier = JWT(key_jar=_context.keyjar, allowed_sign_algs=[self.alg]) try: _payload = verifier.unpack(token) diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index 9db578ac..0706ee98 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -90,7 +90,7 @@ def verify(self, *args, **kwargs): raise NotImplementedError def unpack_token(self, token): - return verify_signed_jwt(token=token, keyjar=self.server_get("endpoint_context").keyjar) + return verify_signed_jwt(token=token, keyjar=self.server_get("context").keyjar) def done(self, areq): """ @@ -106,7 +106,7 @@ def done(self, areq): return False def cookie_info(self, cookie: List[dict], client_id: str) -> dict: - _context = self.server_get("endpoint_context") + _context = self.server_get("context") logger.debug("Value cookies: {}".format(cookie)) if cookie is None: @@ -192,7 +192,7 @@ def __call__(self, **kwargs): ) if not self.server_get: raise Exception(f"{self.__class__.__name__} doesn't have a working server_get") - _context = self.server_get("endpoint_context") + _context = self.server_get("context") # Stores information need afterwards in a signed JWT that then # appears as a hidden input in the form jws = create_signed_jwt(_context.issuer, _context.keyjar, **kwargs) diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index d42a2325..e22a46c7 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -193,7 +193,7 @@ def create_method(self): def test_client_secret_jwt(self): client_keyjar = KeyJar() client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) - # The only own key the client has a this point + # The only own key the client has at this point client_keyjar.add_symmetric("", client_secret, ["sig"]) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") @@ -475,7 +475,7 @@ def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.server_get) - self.endpoint_context = self.server.server_get("endpoint_context") + self.endpoint_context = self.server.server_get("context") def test_verify_per_client(self): self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] @@ -612,7 +612,7 @@ def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.server_get) - self.endpoint_context = self.server.server_get("endpoint_context") + self.endpoint_context = self.server.server_get("context") def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index 55ab886c..badd4842 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -428,7 +428,7 @@ class TestVerify: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.server_get("endpoint_context") + self.endpoint_context = self.server.server_get("context") def test_verify_per_client(self): self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] @@ -565,7 +565,7 @@ class TestVerify2: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.server_get("endpoint_context") + self.endpoint_context = self.server.server_get("context") def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 3f8e2d56..ecd0fdf9 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -272,7 +272,7 @@ def create_endpoint(self): self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.server_get("endpoint_context").keyjar.add_symmetric( + self.endpoint.server_get("context").keyjar.add_symmetric( "client_1", "hemligtkodord1234567890" ) @@ -334,24 +334,24 @@ def test_do_response_code_token(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri(self.endpoint.server_get("endpoint_context"), request, "redirect_uri") + verify_uri(self.endpoint.server_get("context"), request, "redirect_uri") def test_verify_uri_fragment(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} request = {"redirect_uri": "https://rp.example.com/cb#foobar"} with pytest.raises(URIError): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_noregistered(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(KeyError): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_unregistered(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -360,7 +360,7 @@ def test_verify_uri_unregistered(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_match(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] } @@ -370,7 +370,7 @@ def test_verify_uri_qp_match(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_mismatch(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] } @@ -392,7 +392,7 @@ def test_verify_uri_qp_mismatch(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } @@ -402,7 +402,7 @@ def test_verify_uri_qp_missing(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing_val(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] } @@ -412,7 +412,7 @@ def test_verify_uri_qp_missing_val(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_no_registered_qp(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -420,7 +420,7 @@ def test_verify_uri_no_registered_qp(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_wrong_uri_type(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -428,7 +428,7 @@ def test_verify_uri_wrong_uri_type(self): verify_uri(_context, request, "post_logout_redirect_uri", "client_id") def test_verify_uri_none_registered(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "post_logout_redirect_uri": [("https://rp.example.com/plrc", {})] } @@ -438,7 +438,7 @@ def test_verify_uri_none_registered(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_get_uri(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = { @@ -449,7 +449,7 @@ def test_get_uri(self): assert get_uri(_context, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_redirect_uri(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -457,7 +457,7 @@ def test_get_uri_no_redirect_uri(self): assert get_uri(_context, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_registered(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -466,7 +466,7 @@ def test_get_uri_no_registered(self): get_uri(_context, request, "post_logout_redirect_uri") def test_get_uri_more_then_one_registered(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = { "redirect_uris": [ ("https://rp.example.com/cb", {}), @@ -489,7 +489,7 @@ def test_create_authn_response(self): scope="openid", ) - self.endpoint.server_get("endpoint_context").cdb["client_id"] = { + self.endpoint.server_get("context").cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", @@ -517,7 +517,7 @@ def test_setup_auth(self): "id_token_signed_response_alg": "RS256", } - kaka = self.endpoint.server_get("endpoint_context").cookie_handler.make_cookie_content( + kaka = self.endpoint.server_get("context").cookie_handler.make_cookie_content( "value", "sso" ) @@ -545,7 +545,7 @@ def test_setup_auth_error(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.server_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -575,7 +575,7 @@ def test_setup_auth_invalid_scope(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _context.cdb["client_id"] = cinfo kaka = _context.cookie_handler.make_cookie_content("value", "sso") @@ -608,7 +608,7 @@ def test_setup_auth_user(self): session_id = self._create_session(request) - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.server_get("context").authn_broker.db["anon"] item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -633,7 +633,7 @@ def test_setup_auth_session_revoked(self): session_id = self._create_session(request) - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") _mngr = _context.session_manager _csi = _mngr[session_id] _csi.revoked = True diff --git a/tests/test_server_24_oauth2_authorization_endpoint_jar.py b/tests/test_server_24_oauth2_authorization_endpoint_jar.py index 5d1f6e30..885e8a0e 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_server_24_oauth2_authorization_endpoint_jar.py @@ -205,7 +205,7 @@ def test_parse_request_parameter(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.server_get("context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -223,7 +223,7 @@ def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.server_get("context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index 6a2d7912..3e16ec65 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -434,7 +434,7 @@ def test_id_token_claims(self): _resp = self.endpoint.process_request(_pr_resp) idt = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.server_get("endpoint_context").keyjar, + keyjar=self.endpoint.server_get("context").keyjar, ) assert idt # from config @@ -445,7 +445,7 @@ def test_id_token_claims(self): def test_re_authenticate(self): request = {"prompt": "login"} - authn = UserAuthnMethod(self.endpoint.server_get("endpoint_context")) + authn = UserAuthnMethod(self.endpoint.server_get("context")) assert re_authenticate(request, authn) def test_id_token_acr(self): @@ -459,7 +459,7 @@ def test_id_token_acr(self): _resp = self.endpoint.process_request(_pr_resp) res = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.server_get("endpoint_context").keyjar, + keyjar=self.endpoint.server_get("context").keyjar, ) assert res res = _resp["response_args"][verified_claim_name("id_token")] @@ -468,24 +468,24 @@ def test_id_token_acr(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri(self.endpoint.server_get("endpoint_context"), request, "redirect_uri") + verify_uri(self.endpoint.server_get("context"), request, "redirect_uri") def test_verify_uri_fragment(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} request = {"redirect_uri": "https://rp.example.com/cb#foobar"} with pytest.raises(URIError): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_noregistered(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(KeyError): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_unregistered(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -494,7 +494,7 @@ def test_verify_uri_unregistered(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_match(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} @@ -502,7 +502,7 @@ def test_verify_uri_qp_match(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_mismatch(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -522,7 +522,7 @@ def test_verify_uri_qp_mismatch(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } @@ -532,7 +532,7 @@ def test_verify_uri_qp_missing(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing_val(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] } @@ -542,7 +542,7 @@ def test_verify_uri_qp_missing_val(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_no_registered_qp(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -550,7 +550,7 @@ def test_verify_uri_no_registered_qp(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_get_uri(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = { @@ -561,7 +561,7 @@ def test_get_uri(self): assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_redirect_uri(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -569,7 +569,7 @@ def test_get_uri_no_redirect_uri(self): assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_registered(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -578,7 +578,7 @@ def test_get_uri_no_registered(self): get_uri(_ec, request, "post_logout_redirect_uri") def test_get_uri_more_then_one_registered(self): - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = { "redirect_uris": [ ("https://rp.example.com/cb", {}), @@ -601,7 +601,7 @@ def test_create_authn_response_id_token(self): scope=["openid", "profile"], ) - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], @@ -629,7 +629,7 @@ def test_create_authn_response_id_token_request_claims(self): scope=["openid"], ) - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.server_get("context") _ec.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index 7b12ab32..eea70c34 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -236,7 +236,7 @@ def test_end_session_endpoint(self): _ = self.session_endpoint.process_request("", http_info=http_info) def _create_cookie(self, session_id): - ec = self.session_endpoint.server_get("endpoint_context") + ec = self.session_endpoint.server_get("context") return ec.new_cookie( name=ec.cookie_handler.name["session"], sid=session_id, @@ -276,7 +276,7 @@ def _auth_with_id_token(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - _info = self.session_endpoint.server_get("endpoint_context").cookie_handler.parse_cookie( + _info = self.session_endpoint.server_get("context").cookie_handler.parse_cookie( "oidc_op", _resp["cookie"] ) # value is a JSON document @@ -336,7 +336,7 @@ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): http_info = {"cookie": [cookie]} msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.server_get("context").keyjar) msg2 = Message(id_token_hint=id_token) msg2[verified_claim_name("id_token_hint")] = msg[verified_claim_name("id_token")] @@ -376,7 +376,7 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): http_info = {"cookie": [cookie]} post_logout_redirect_uri = join_query( - *self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + *self.session_endpoint.server_get("context").cdb["client_1"][ "post_logout_redirect_uri" ] ) @@ -403,7 +403,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): post_logout_redirect_uri = "https://demo.example.com/log_out" msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.server_get("context").keyjar) with pytest.raises(RedirectURIError): self.session_endpoint.process_request( @@ -420,14 +420,14 @@ def test_back_channel_logout_no_backchannel_logout_uri(self): info = self._code_auth("1234567") res = self.session_endpoint.do_back_channel_logout( - self.session_endpoint.server_get("endpoint_context").cdb["client_1"], info["session_id"] + self.session_endpoint.server_get("context").cdb["client_1"], info["session_id"] ) assert res is None def test_back_channel_logout(self): info = self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_id"] = "client_1" res = self.session_endpoint.do_back_channel_logout(_cdb, info["session_id"]) @@ -442,7 +442,7 @@ def test_back_channel_logout(self): def test_front_channel_logout(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["client_id"] = "client_1" res = do_front_channel_logout_iframe(_cdb, ISS, "_sid_") @@ -451,7 +451,7 @@ def test_front_channel_logout(self): def test_front_channel_logout_session_required(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -467,7 +467,7 @@ def test_front_channel_logout_session_required(self): def test_front_channel_logout_with_query(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout?entity_id=foo" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -489,10 +489,10 @@ def test_logout_from_client_bc(self): _code, client_session_info=True, handler_key="authorization_code" ) - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "client_id" ] = "client_1" @@ -517,12 +517,12 @@ def test_logout_from_client_fc(self): _code, client_session_info=True, handler_key="authorization_code" ) - # del self.session_endpoint.server_get("endpoint_context").cdb['client_1'][ + # del self.session_endpoint.server_get("context").cdb['client_1'][ # 'backchannel_logout_uri'] - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "client_id" ] = "client_1" @@ -554,13 +554,13 @@ def test_logout_from_client(self): ) # client0 - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "backchannel_logout_uri"] = "https://example.com/bc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "client_id"] = "client_1" - self.session_endpoint.server_get("endpoint_context").cdb["client_2"][ + self.session_endpoint.server_get("context").cdb["client_2"][ "frontchannel_logout_uri"] = "https://example.com/fc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_2"][ + self.session_endpoint.server_get("context").cdb["client_2"][ "client_id"] = "client_2" res = self.session_endpoint.logout_all_clients(_session_info["branch_id"]) @@ -599,7 +599,7 @@ def test_do_verified_logout(self): _session_info = self.session_manager.get_session_info_by_token( _code, handler_key="authorization_code" ) - _cdb = self.session_endpoint.server_get("endpoint_context").cdb + _cdb = self.session_endpoint.server_get("context").cdb _cdb["client_1"]["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_1"]["client_id"] = "client_1" @@ -628,21 +628,21 @@ def test_logout_from_client_no_session(self): self._code_auth2("abcdefg") # client0 - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.session_endpoint.server_get("context").cdb["client_1"][ "client_id" ] = "client_1" - self.session_endpoint.server_get("endpoint_context").cdb["client_2"][ + self.session_endpoint.server_get("context").cdb["client_2"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.server_get("endpoint_context").cdb["client_2"][ + self.session_endpoint.server_get("context").cdb["client_2"][ "client_id" ] = "client_2" _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_info["branch_id"]) - self.session_endpoint.server_get("endpoint_context").session_manager.delete([_uid, _cid]) + self.session_endpoint.server_get("context").session_manager.delete([_uid, _cid]) with pytest.raises(InvalidBranchID): self.session_endpoint.logout_all_clients(_session_info["branch_id"]) diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index 04748917..773ef1d0 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -231,7 +231,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.token_endpoint.server_get("endpoint_context"), + endpoint_context=self.token_endpoint.server_get("context"), token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -242,7 +242,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): def _get_access_token(self, areq): session_id = self._create_session(areq) # Consent handling - grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, areq) + grant = self.token_endpoint.server_get("context").authz(session_id, areq) self.session_manager[session_id] = grant # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) @@ -256,7 +256,7 @@ def test_parse_no_authn(self): def test_parse_with_client_auth_in_req(self): access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.server_get("context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -273,7 +273,7 @@ def test_parse_with_wrong_client_authn(self): _basic_token = "{}:{}".format( "client_1", - self.introspection_endpoint.server_get("endpoint_context").cdb["client_1"][ + self.introspection_endpoint.server_get("context").cdb["client_1"][ "client_secret" ], ) @@ -293,7 +293,7 @@ def test_process_request(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get("endpoint_context").cdb[ + "client_secret": self.introspection_endpoint.server_get("context").cdb[ "client_1" ]["client_secret"], } @@ -317,7 +317,7 @@ def test_do_response(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get("endpoint_context").cdb[ + "client_secret": self.introspection_endpoint.server_get("context").cdb[ "client_1" ]["client_secret"], } @@ -348,7 +348,7 @@ def test_do_response(self): def test_do_response_no_token(self): # access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.server_get("context") _req = self.introspection_endpoint.parse_request( { "client_id": "client_1", @@ -360,7 +360,7 @@ def test_do_response_no_token(self): def test_access_token(self): access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.server_get("context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -378,12 +378,12 @@ def test_code(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.server_get("context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.server_get("context") _req = self.introspection_endpoint.parse_request( { diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index fbfc961f..7b9c1ea0 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -327,7 +327,7 @@ def test_essential_per_client(self, conf): authn_endpoint = server.server_get("endpoint", "authorization") token_endpoint = server.server_get("endpoint", "token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.server_get("endpoint_context") + endpoint_context = server.server_get("context") endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = True _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -342,7 +342,7 @@ def test_not_essential_per_client(self, conf): authn_endpoint = server.server_get("endpoint", "authorization") token_endpoint = server.server_get("endpoint", "token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.server_get("endpoint_context") + endpoint_context = server.server_get("context") endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = False _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -440,7 +440,7 @@ def test_missing_authz_endpoint(): server = Server(configuration) add_pkce_support(server.server_get("endpoints")) - assert "pkce" not in server.server_get("endpoint_context").args + assert "pkce" not in server.server_get("context").args def test_missing_token_endpoint(): @@ -465,4 +465,4 @@ def test_missing_token_endpoint(): server = Server(configuration) add_pkce_support(server.server_get("endpoints")) - assert "pkce" not in server.server_get("endpoint_context").args + assert "pkce" not in server.server_get("context").args diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 7c1e70c8..a18342cf 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -229,7 +229,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint.server_get("endpoint_context"), + endpoint_context=self.endpoint.server_get("context"), token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 52570e68..46c521e7 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -254,7 +254,7 @@ def _mint_code(self, grant, session_id, index=1): # Constructing an authorization code is now done _code = grant.mint_token( session_id, - endpoint_context=self.endpoint[index].server_get("endpoint_context"), + endpoint_context=self.endpoint[index].server_get("context"), token_class="authorization_code", token_handler=self.session_manager[index].token_handler["authorization_code"], ) @@ -272,7 +272,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None, index=1): _token = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint[index].server_get("endpoint_context"), + endpoint_context=self.endpoint[index].server_get("context"), token_class="access_token", token_handler=self.session_manager[index].token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token diff --git a/tests/test_server_61_add_on.py b/tests/test_server_61_add_on.py index 366630eb..b83a5ea3 100644 --- a/tests/test_server_61_add_on.py +++ b/tests/test_server_61_add_on.py @@ -148,7 +148,7 @@ def create_endpoint(self): self.endpoint = server.server_get("endpoint", "authorization") def test_process_request(self): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.server_get("context") assert _context.add_on["extra_args"] == {"authorization": {"iss": "issuer"}} _pr_resp = self.endpoint.parse_request(AUTH_REQ) diff --git a/tests/test_y_actor_01.py b/tests/test_y_actor_01.py new file mode 100644 index 00000000..3ec9cbc2 --- /dev/null +++ b/tests/test_y_actor_01.py @@ -0,0 +1,351 @@ + +import copy +import os + +import pytest +from cryptojwt.jwt import JWT +from cryptojwt.key_jar import KeyJar +from cryptojwt.key_jar import init_key_jar + +from idpyoidc.actor import CIBAClient +from idpyoidc.actor import CIBAServer +from idpyoidc.client.entity import Entity +from idpyoidc.message.oidc.backchannel_authentication import AuthenticationRequest +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from idpyoidc.server.authn_event import create_authn_event +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.oidc.backchannel_authentication import BackChannelAuthentication +from idpyoidc.server.oidc.backchannel_authentication import ClientNotification +from idpyoidc.server.oidc.token import Token +from idpyoidc.server.user_authn.authn_context import MOBILETWOFACTORCONTRACT +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) +ISSUER_1 = "https://example.com/actor1" +ISSUER_2 = "https://example.com/actor2" + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} + +SERVER_CONFIG = { + "httpc_params": {"verify": False, "timeout": 1}, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYSPEC}, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "base_claims": {"eduperson_scoped_affiliation": None}, + "add_claims_by_scope": True, + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": {"lifetime": 3600}, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "base_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + }, + }, + "endpoint": { + "token": {"path": "token", "class": Token, "kwargs": {}}, + }, + "client_authn": verify_client, + "session_params": SESSION_PARAMS, +} + + +def _create_client(issuer, client_id, service): + client_config = { + "issuer": issuer, + "client_id": client_id, + "client_secret": rndstr(24), + "redirect_uris": [f"https://example.com/{client_id}/authz_cb"], + "behaviour": {"response_types": ["code"]}, + "client_authn_methods": { + "client_notification_authn": "idpyoidc.client.oidc.backchannel_authentication.ClientNotificationAuthn" + }, + } + _services = { + "discovery": { + "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery" + }, + "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, + } + _services.update(service) + + _cli_1_key = init_key_jar(key_defs=KEYSPEC) + + return Entity(config=client_config, services=_services, keyjar=_cli_1_key) + + +def _create_server(issuer, endpoint, port, extra_conf=None): + _config = copy.deepcopy(SERVER_CONFIG) + _config["issuer"] = issuer + _config["endpoint"].update(endpoint) + if extra_conf: + _config.update(extra_conf) + + return Server(OPConfiguration(conf=_config, base_path=BASEDIR, domain="127.0.0.1", port=port)) + + +# Locally defined +def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) -> str: + _jwt = JWT(keyjar) + _info = _jwt.unpack(login_hint_token) + # here comes the special knowledge + _sub_id = _info.get("sub_id") + _sub = "" + if _sub_id: + if _sub_id["format"] == "phone": + _sub = "tel:" + _sub_id["phone"] + elif _sub_id["format"] == "mail": + _sub = "mail:" + _sub_id["mail"] + + if _sub and context and context.login_hint_lookup: + try: + _sub = context.login_hint_lookup(_sub) + except KeyError: + _sub = "" + + return _sub + + +class TestPushActor: + @pytest.fixture(autouse=True) + def create_actor(self): + # ============== ACTOR 1 ============== + # Actor 1 can use Authentication Service and provides a Client Notification Endpoint + actor_1 = CIBAClient() + actor_1.client = _create_client( + ISSUER_2, + "actor1", + { + "authentication": { + "class": "idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication" + } + }, + ) + + endpoint = { + "client_notify": { + "path": "notify", + "class": ClientNotification, + "kwargs": {"client_authn_method": ["client_notification_authn"]}, + } + } + extra = { + "client_authn_methods": { + "client_notification_authn": "idpyoidc.server.oidc.backchannel_authentication.ClientNotificationAuthn" + } + } + + actor_1.server = _create_server(ISSUER_1, endpoint, 6000, extra_conf=extra) + + self.actor_1 = actor_1 + + # ============== ACTOR 2 ============== + # Provides Authentication endpoint and can use the Client notification service + actor_2 = CIBAServer() + actor_2.client = _create_client( + ISSUER_1, + "actor2", + { + "notification": { + "class": "idpyoidc.client.oidc.backchannel_authentication.ClientNotification" + } + }, + ) + endpoint = { + "backchannel_authentication": { + "path": "authentication", + "class": BackChannelAuthentication, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ], + "parse_login_hint_token": {"func": parse_login_hint_token}, + }, + } + } + extra = { + "login_hint_lookup": {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, + "userinfo": { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": "users.json"}, + }, + } + actor_2.server = _create_server(ISSUER_2, endpoint, 7000, extra) + + # register clients with servers. + _server_context = actor_1.server.server_get("context") + _client_context = actor_2.client.client_get("service_context") + _server_context.cdb = { + _client_context.client_id: { + "client_secret": _client_context.client_secret, + }, + actor_2.server.server_get("context").issuer: { + "client_secret": _client_context.client_secret + }, + } + _server_context = actor_2.server.server_get("context") + _client_context = actor_1.client.client_get("service_context") + _server_context.cdb = { + _client_context.client_id: {"client_secret": _client_context.client_secret}, + actor_1.server.server_get("context").issuer: { + "client_secret": _client_context.client_secret + }, + } + + # Transfer provider metadata 1->2 and 2->1 + _client_context = actor_2.client.client_get("service_context") + _server_context = actor_1.server.server_get("context") + _client_context.provider_info = _server_context.provider_info + + _client_context = actor_1.client.client_get("service_context") + _server_context = actor_2.server.server_get("context") + _client_context.provider_info = _server_context.provider_info + + _server_context.parse_login_hint_token = parse_login_hint_token + + # keys + _client_keyjar = actor_2.client.client_get("service_context").keyjar + _server_keyjar = actor_1.server.server_get("context").keyjar + _server_keyjar.import_jwks(_client_keyjar.export_jwks(), "actor2") + _client_keyjar.import_jwks(_server_keyjar.export_jwks(), ISSUER_1) + + _client_keyjar = actor_1.client.client_get("service_context").keyjar + _server_keyjar = actor_2.server.server_get("context").keyjar + _server_keyjar.import_jwks(_client_keyjar.export_jwks(), "actor1") + _client_keyjar.import_jwks(_server_keyjar.export_jwks(), ISSUER_2) + + self.actor_1 = actor_1 + self.actor_2 = actor_2 + + def _create_session( + self, server, user_id, auth_req, sub_type="public", sector_identifier="", authn_info="" + ): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(user_id, authn_info=authn_info) + _session_manager = server.endpoint_context.session_manager + return _session_manager.create_session( + ae, authz_req, user_id, client_id=client_id, sub_type=sub_type + ) + + def test_init(self): + assert self.actor_1.client + assert self.actor_2.client + assert self.actor_1.server + assert self.actor_2.server + + def test_query(self): + _req = self.actor_1.create_authentication_request( + scope="openid email example-scope", + binding_message="W4SCT", + login_hint="mail:diana@example.org", + ) + assert _req + assert _req["method"] == "GET" + assert isinstance(_req["request"], AuthenticationRequest) + assert _req["request"]["login_hint"] == "mail:diana@example.org" + + # On the CIBA server side + _endpoint = self.actor_2.server.server_get("endpoint", "backchannel_authentication") + _request = _endpoint.parse_request(_req["request"].to_urlencoded()) + assert _request + # If ping mode + assert "client_notification_token" in _request + req_user = _endpoint.do_request_user(_request) + assert req_user == "diana" + # Construct response to the authentication request + _info = _endpoint.process_request(_request) + assert _info + + # User interaction with the authentication device returns some authentication info + + session_id_2 = self._create_session( + self.actor_2.server, req_user, _request, authn_info=MOBILETWOFACTORCONTRACT + ) + + # Create fake token response + token_request = { + "grant_type": "urn:openid:params:grant-type:ciba", + "auth_req_id": _info["response_args"]["auth_req_id"], + "client_id": "actor1", + } + _token_endpoint = self.actor_2.server.server_get("endpoint", "token") + _treq = _token_endpoint.parse_request(token_request) + # Construct response to the authentication request + _tinfo = _token_endpoint.process_request(_treq) + assert _tinfo + + # Send the response to the client notification endpoint + + _tinfo["response_args"]["client_notification_token"] = _request["client_notification_token"] + _notification_service = self.actor_2.client.client_get("service", "client_notification") + _not_req = _notification_service.get_request_parameters( + request_args=_tinfo["response_args"], authn_method="client_notification_authn" + ) + + assert _not_req + + # The receiver of the notification + + _ninfo = self.actor_1.do_client_notification( + _not_req["body"], http_info={"headers": _not_req["headers"]} + ) + assert _ninfo is None \ No newline at end of file From 9749b22c1ff6f4ce82538b3a3a2ee994443d20fe Mon Sep 17 00:00:00 2001 From: roland Date: Sat, 15 Oct 2022 08:35:54 +0200 Subject: [PATCH 41/76] Fedservice support --- src/idpyoidc/client/entity.py | 7 +------ src/idpyoidc/client/oauth2/__init__.py | 9 ++++++--- src/idpyoidc/client/oidc/__init__.py | 25 +++++++++++++++---------- src/idpyoidc/message/__init__.py | 3 +++ 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 9849d785..7a1b2a48 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -96,12 +96,7 @@ def __init__( _srvs = None if not _srvs: - if services: - _srvs = services - elif client_type == "oauth2": - _srvs = DEFAULT_OAUTH2_SERVICES - else: - _srvs = DEFAULT_OIDC_SERVICES + _srvs = DEFAULT_OAUTH2_SERVICES self._service = init_services(service_definitions=_srvs, superior_get=self.entity_get) diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 6287abc0..170822db 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -1,5 +1,6 @@ import logging from json import JSONDecodeError +from typing import Callable from typing import Optional from idpyoidc.client.entity import Entity @@ -37,9 +38,10 @@ def __init__( keyjar=None, verify_ssl=True, config=None, - httplib=None, + httpc=None, services=None, httpc_params=None, + superior_get: Optional[Callable] = None, client_type: Optional[str] = "" **kwargs ): @@ -51,7 +53,7 @@ def __init__( :param config: Configuration information passed on to the :py:class:`idpyoidc.client.service_context.ServiceContext` initialization - :param httplib: A HTTP client to use + :param httpc: A HTTP client to use :param services: A list of service definitions :param httpc_params: HTTP request arguments :return: Client instance @@ -66,10 +68,11 @@ def __init__( config=config, services=services, httpc_params=httpc_params, + superior_get=superior_get client_type=client_type ) - self.http = httplib or HTTPLib(httpc_params) + self.http = httpc or HTTPLib(httpc_params) if isinstance(config, Configuration): _add_ons = config.conf.get("add_ons") diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 1e264081..759ede90 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -1,5 +1,7 @@ import json import logging +from typing import Callable +from typing import Optional from idpyoidc.client import oauth2 from idpyoidc.client.client_auth import BearerHeader @@ -72,17 +74,19 @@ class FetchException(Exception): class RP(oauth2.Client): + def __init__( - self, - keyjar=None, - verify_ssl=True, - config=None, - httplib=None, - services=None, - httpc_params=None, - **kwargs + self, + keyjar=None, + verify_ssl=True, + config=None, + httpc=None, + services=None, + httpc_params=None, + superior_get: Optional[Callable] = None, + **kwargs ): - + self.superior_get = superior_get _srvs = services or DEFAULT_OIDC_SERVICES oauth2.Client.__init__( @@ -90,10 +94,11 @@ def __init__( keyjar=keyjar, verify_ssl=verify_ssl, config=config, - httplib=httplib, + httpc=httpc, services=_srvs, httpc_params=httpc_params, client_type="oidc", + superior_get=superior_get, **kwargs ) diff --git a/src/idpyoidc/message/__init__.py b/src/idpyoidc/message/__init__.py index 234cb014..010155fb 100644 --- a/src/idpyoidc/message/__init__.py +++ b/src/idpyoidc/message/__init__.py @@ -673,6 +673,9 @@ def request(self, location, fragment_enc=False): """ _l = as_unicode(location) _qp = as_unicode(self.to_urlencoded()) + if not _qp: + return _l + if fragment_enc: return "%s#%s" % (_l, _qp) else: From c1f0f2f5f470ae3cb274c2b26577848e3da101bb Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 12:24:30 +0100 Subject: [PATCH 42/76] Rebased onto improved --- example/flask_op/views.py | 32 +- example/flask_rp/views.py | 10 +- pyproject.toml | 2 +- src/idpyoidc/__init__.py | 2 +- src/idpyoidc/actor/__init__.py | 6 +- src/idpyoidc/client/client_auth.py | 49 +- src/idpyoidc/client/entity.py | 99 +-- src/idpyoidc/client/oauth2/__init__.py | 108 +-- src/idpyoidc/client/oauth2/access_token.py | 8 +- src/idpyoidc/client/oauth2/add_on/dpop.py | 2 +- .../oauth2/add_on/identity_assurance.py | 2 +- src/idpyoidc/client/oauth2/add_on/pkce.py | 6 +- .../oauth2/add_on/pushed_authorization.py | 7 +- src/idpyoidc/client/oauth2/authorization.py | 12 +- .../client_credentials/cc_access_token.py | 6 +- .../cc_refresh_access_token.py | 8 +- .../client/oauth2/refresh_access_token.py | 8 +- src/idpyoidc/client/oauth2/server_metadata.py | 16 +- src/idpyoidc/client/oauth2/utils.py | 4 +- src/idpyoidc/client/oidc/__init__.py | 43 +- src/idpyoidc/client/oidc/access_token.py | 23 +- src/idpyoidc/client/oidc/authorization.py | 28 +- .../client/oidc/backchannel_authentication.py | 8 +- src/idpyoidc/client/oidc/check_id.py | 6 +- src/idpyoidc/client/oidc/check_session.py | 6 +- src/idpyoidc/client/oidc/end_session.py | 20 +- .../client/oidc/provider_info_discovery.py | 13 +- src/idpyoidc/client/oidc/read_registration.py | 4 +- .../client/oidc/refresh_access_token.py | 2 +- src/idpyoidc/client/oidc/registration.py | 8 +- src/idpyoidc/client/oidc/userinfo.py | 18 +- src/idpyoidc/client/oidc/utils.py | 5 +- src/idpyoidc/client/oidc/webfinger.py | 8 +- src/idpyoidc/client/rp_handler.py | 57 +- src/idpyoidc/client/service.py | 37 +- src/idpyoidc/client/service_context.py | 5 +- src/idpyoidc/node.py | 154 +++++ src/idpyoidc/server/__init__.py | 44 +- src/idpyoidc/server/authz/__init__.py | 16 +- src/idpyoidc/server/client_authn.py | 82 ++- src/idpyoidc/server/client_configure.py | 4 +- src/idpyoidc/server/configure.py | 2 +- src/idpyoidc/server/endpoint.py | 33 +- src/idpyoidc/server/endpoint_context.py | 19 +- src/idpyoidc/server/login_hint.py | 8 +- src/idpyoidc/server/oauth2/add_on/dpop.py | 22 +- .../server/oauth2/add_on/extra_args.py | 10 +- src/idpyoidc/server/oauth2/authorization.py | 89 +-- src/idpyoidc/server/oauth2/introspection.py | 8 +- .../server/oauth2/pushed_authorization.py | 6 +- src/idpyoidc/server/oauth2/token.py | 6 +- src/idpyoidc/server/oauth2/token_helper.py | 36 +- .../server/oidc/add_on/custom_scopes.py | 2 +- src/idpyoidc/server/oidc/add_on/pkce.py | 16 +- src/idpyoidc/server/oidc/authorization.py | 6 +- .../server/oidc/backchannel_authentication.py | 26 +- src/idpyoidc/server/oidc/discovery.py | 2 +- src/idpyoidc/server/oidc/provider_config.py | 10 +- src/idpyoidc/server/oidc/read_registration.py | 6 +- src/idpyoidc/server/oidc/registration.py | 41 +- src/idpyoidc/server/oidc/session.py | 37 +- src/idpyoidc/server/oidc/token_helper.py | 34 +- src/idpyoidc/server/oidc/userinfo.py | 16 +- src/idpyoidc/server/scopes.py | 8 +- src/idpyoidc/server/session/claims.py | 34 +- src/idpyoidc/server/session/grant.py | 28 +- src/idpyoidc/server/session/manager.py | 4 +- src/idpyoidc/server/token/handler.py | 12 +- src/idpyoidc/server/token/id_token.py | 44 +- src/idpyoidc/server/token/jwt_token.py | 39 +- .../server/user_authn/authn_context.py | 24 +- src/idpyoidc/server/user_authn/user.py | 33 +- src/idpyoidc/server/util.py | 18 +- tests/request123456.jwt | 2 +- tests/test_12_context.py | 19 + tests/test_client_02_entity.py | 5 - tests/test_client_04_service.py | 14 +- tests/test_client_06_client_authn.py | 71 +- tests/test_client_12_client_auth.py | 71 +- tests/test_client_13_service_context.py | 0 .../test_client_14_service_context_impexp.py | 14 +- tests/test_client_18_service.py | 3 +- tests/test_client_19_webfinger.py | 20 +- tests/test_client_20_oauth2.py | 28 +- tests/test_client_21_oidc_service.py | 82 +-- tests/test_client_22_oidc.py | 16 +- tests/test_client_23_pkce.py | 23 +- tests/test_client_25_cc_oauth2_service.py | 30 +- tests/test_client_26_read_registration.py | 4 +- tests/test_client_27_conversation.py | 20 +- tests/test_client_28_rp_handler_oidc.py | 68 +- tests/test_client_29_pushed_auth.py | 6 +- tests/test_client_30_rph_defaults.py | 17 +- tests/test_client_31_oauth2_persistent.py | 24 +- tests/test_client_32_oidc_persistent.py | 46 +- tests/test_client_40_dpop.py | 10 +- tests/test_client_41_rp_handler_persistent.py | 52 +- tests/test_client_50_ciba.py | 2 +- tests/test_client_51_identity_assurance.py | 10 +- tests/test_server_00a_client_configure.py | 4 +- tests/test_server_01_claims.py | 5 +- tests/test_server_03_authz_handling.py | 6 +- tests/test_server_06_grant.py | 68 +- tests/test_server_08_id_token.py | 6 +- tests/test_server_09_authn_context.py | 4 +- tests/test_server_10_session_manager.py | 2 +- tests/test_server_12_session_life.py | 16 +- tests/test_server_13_user_authn.py | 6 +- tests/test_server_16_endpoint.py | 6 +- tests/test_server_17_client_authn.py | 73 ++- tests/test_server_20b_claims.py | 16 +- tests/test_server_20d_client_authn.py | 58 +- tests/test_server_20e_jwt_token.py | 8 +- tests/test_server_20f_userinfo.py | 4 +- .../test_server_21_oidc_discovery_endpoint.py | 2 +- ...server_22_oidc_provider_config_endpoint.py | 4 +- ...st_server_23_oidc_registration_endpoint.py | 2 +- ...server_24_oauth2_authorization_endpoint.py | 46 +- ...er_24_oauth2_authorization_endpoint_jar.py | 6 +- tests/test_server_24_oauth2_token_endpoint.py | 10 +- ...t_server_24_oidc_authorization_endpoint.py | 74 +-- .../test_server_26_oidc_userinfo_endpoint.py | 62 +- tests/test_server_30_oidc_end_session.py | 70 +- tests/test_server_31_oauth2_introspection.py | 42 +- .../test_server_32_oidc_read_registration.py | 4 +- tests/test_server_33_oauth2_pkce.py | 30 +- tests/test_server_34_oidc_sso.py | 20 +- tests/test_server_35_oidc_token_endpoint.py | 16 +- tests/test_server_36_oauth2_token_exchange.py | 6 +- ...t_server_40_oauth2_pushed_authorization.py | 4 +- tests/test_server_50_persistence.py | 34 +- tests/test_server_60_dpop.py | 4 +- tests/test_server_61_add_on.py | 4 +- tests/test_y_actor_01.py | 351 ---------- tests/x_test_ciba_01_backchannel_auth.py | 617 ++++++++++++++++++ 135 files changed, 2259 insertions(+), 1745 deletions(-) create mode 100644 src/idpyoidc/node.py create mode 100644 tests/test_12_context.py create mode 100644 tests/test_client_13_service_context.py create mode 100644 tests/x_test_ciba_01_backchannel_auth.py diff --git a/example/flask_op/views.py b/example/flask_op/views.py index c45dce04..5da08b0a 100644 --- a/example/flask_op/views.py +++ b/example/flask_op/views.py @@ -119,7 +119,7 @@ def verify(authn_method): auth_args = authn_method.unpack_token(kwargs['token']) authz_request = AuthorizationRequest().from_urlencoded(auth_args['query']) - endpoint = current_app.server.server_get("endpoint", 'authorization') + endpoint = current_app.server.upstream_get("endpoint", 'authorization') _session_id = endpoint.create_session(authz_request, username, auth_args['authn_class_ref'], auth_args['iat'], authn_method) @@ -133,7 +133,7 @@ def verify(authn_method): @oidc_op_views.route('/verify/user', methods=['GET', 'POST']) def verify_user(): - authn_method = current_app.server.server_get( + authn_method = current_app.server.upstream_get( "endpoint_context").authn_broker.get_method_by_id('user') try: return verify(authn_method) @@ -143,7 +143,7 @@ def verify_user(): @oidc_op_views.route('/verify/user_pass_jinja', methods=['GET', 'POST']) def verify_user_pass_jinja(): - authn_method = current_app.server.server_get( + authn_method = current_app.server.upstream_get( "endpoint_context").authn_broker.get_method_by_id('user') try: return verify(authn_method) @@ -154,9 +154,9 @@ def verify_user_pass_jinja(): @oidc_op_views.route('/.well-known/') def well_known(service): if service == 'openid-configuration': - _endpoint = current_app.server.server_get("endpoint", 'provider_config') + _endpoint = current_app.server.upstream_get("endpoint", 'provider_config') elif service == 'webfinger': - _endpoint = current_app.server.server_get("endpoint", 'discovery') + _endpoint = current_app.server.upstream_get("endpoint", 'discovery') else: return make_response('Not supported', 400) @@ -166,45 +166,45 @@ def well_known(service): @oidc_op_views.route('/registration', methods=['GET', 'POST']) def registration(): return service_endpoint( - current_app.server.server_get("endpoint", 'registration')) + current_app.server.upstream_get("endpoint", 'registration')) @oidc_op_views.route('/registration_api', methods=['GET', 'DELETE']) def registration_api(): if request.method == "DELETE": return service_endpoint( - current_app.server.server_get("endpoint", 'registration_delete')) + current_app.server.upstream_get("endpoint", 'registration_delete')) else: return service_endpoint( - current_app.server.server_get("endpoint", 'registration_read')) + current_app.server.upstream_get("endpoint", 'registration_read')) @oidc_op_views.route('/authorization') def authorization(): return service_endpoint( - current_app.server.server_get("endpoint", 'authorization')) + current_app.server.upstream_get("endpoint", 'authorization')) @oidc_op_views.route('/token', methods=['GET', 'POST']) def token(): return service_endpoint( - current_app.server.server_get("endpoint", 'token')) + current_app.server.upstream_get("endpoint", 'token')) @oidc_op_views.route('/introspection', methods=['POST']) def introspection_endpoint(): return service_endpoint( - current_app.server.server_get("endpoint", 'introspection')) + current_app.server.upstream_get("endpoint", 'introspection')) @oidc_op_views.route('/userinfo', methods=['GET', 'POST']) def userinfo(): return service_endpoint( - current_app.server.server_get("endpoint", 'userinfo')) + current_app.server.upstream_get("endpoint", 'userinfo')) @oidc_op_views.route('/session', methods=['GET']) def session_endpoint(): return service_endpoint( - current_app.server.server_get("endpoint", 'session')) + current_app.server.upstream_get("endpoint", 'session')) IGNORE = ["cookie", "user-agent"] @@ -298,7 +298,7 @@ def check_session_iframe(): req_args = dict([(k, v) for k, v in request.form.items()]) if req_args: - _context = current_app.server.server_get("endpoint_context") + _context = current_app.server.upstream_get("endpoint_context") # will contain client_id and origin if req_args['origin'] != _context.issuer: return 'error' @@ -314,7 +314,7 @@ def check_session_iframe(): @oidc_op_views.route('/verify_logout', methods=['GET', 'POST']) def verify_logout(): - part = urlparse(current_app.server.server_get("endpoint_context").issuer) + part = urlparse(current_app.server.upstream_get("endpoint_context").issuer) page = render_template('logout.html', op=part.hostname, do_logout='rp_logout', sjwt=request.args['sjwt']) return page @@ -322,7 +322,7 @@ def verify_logout(): @oidc_op_views.route('/rp_logout', methods=['GET', 'POST']) def rp_logout(): - _endp = current_app.server.server_get("endpoint", 'session') + _endp = current_app.server.upstream_get("endpoint", 'session') _info = _endp.unpack_signed_jwt(request.form['sjwt']) try: request.form['logout'] diff --git a/example/flask_rp/views.py b/example/flask_rp/views.py index c5ede9d5..e3f7b64f 100644 --- a/example/flask_rp/views.py +++ b/example/flask_rp/views.py @@ -100,7 +100,7 @@ def finalize(op_identifier, request_args): logger.error(rp.response[0].decode()) return rp.response[0], rp.status_code - _context = rp.client_get("service_context") + _context = rp.client_get("context") session['client_id'] = _context.get('client_id') session['state'] = request_args.get('state') @@ -123,7 +123,7 @@ def finalize(op_identifier, request_args): raise excp if 'userinfo' in res: - _context = rp.client_get("service_context") + _context = rp.client_get("context") endpoints = {} for k, v in _context.provider_info.items(): if k.endswith('_endpoint'): @@ -197,7 +197,7 @@ def session_iframe(): # session management logger.debug('session_iframe request_args: {}'.format(request.args)) _rp = get_rp(session['op_identifier']) - _context = _rp.client_get("service_context") + _context = _rp.client_get("context") session_change_url = "{}/session_change".format(_context.base_url) _issuer = current_app.rph.hash2issuer[session['op_identifier']] @@ -237,7 +237,7 @@ def session_change(): def session_logout(op_identifier): _rp = get_rp(op_identifier) logger.debug('post_logout') - return "Post logout from {}".format(_rp.client_get("service_context").issuer) + return "Post logout from {}".format(_rp.client_get("context").issuer) # RP initiated logout @@ -267,7 +267,7 @@ def frontchannel_logout(op_identifier): _rp = get_rp(op_identifier) sid = request.args['sid'] _iss = request.args['iss'] - if _iss != _rp.client_get("service_context").get('issuer'): + if _iss != _rp.client_get("context").get('issuer'): return 'Bad request', 400 _state = _rp.session_interface.get_state_by_sid(sid) _rp.session_interface.remove_state(_state) diff --git a/pyproject.toml b/pyproject.toml index c0a9cd4d..32305ba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" [metadata] name = "idpyoidc" -version = "1.4.0" +version = "2.0.0" author = "Roland Hedberg" author_email = "roland@catalogix.se" description = "Everything OAuth2 and OIDC" diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 5b03c94b..1ca2d4a7 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,5 +1,5 @@ __author__ = "Roland Hedberg" -__version__ = "1.4.0" +__version__ = "2.0.0" import os from typing import Dict diff --git a/src/idpyoidc/actor/__init__.py b/src/idpyoidc/actor/__init__.py index cd62398a..6132b08c 100644 --- a/src/idpyoidc/actor/__init__.py +++ b/src/idpyoidc/actor/__init__.py @@ -20,7 +20,7 @@ def __init__( self.context = {} def create_authentication_request(self, scope, binding_message, login_hint): - _service = self.client.superior_get("service", "backchannel_authentication") + _service = self.client.upstream_get("service", "backchannel_authentication") client_notification_token = uuid4().hex @@ -36,7 +36,7 @@ def create_authentication_request(self, scope, binding_message, login_hint): self.context[client_notification_token] = { "authentication_request": request, - "client_id": _service.superior_get("context").issuer, + "client_id": _service.upstream_get("context").issuer, } return request @@ -45,7 +45,7 @@ def get_client_id_from_token(self, token): return _context["client_id"] def do_client_notification(self, msg, http_info): - _notification_endpoint = self.server.server_get("endpoint", "client_notification") + _notification_endpoint = self.server.upstream_get("endpoint", "client_notification") _nreq = _notification_endpoint.parse_request( msg, http_info, get_client_id_from_token=self.get_client_id_from_token ) diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index b04cf595..760812f3 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -95,7 +95,7 @@ def _get_passwd(request, service, **kwargs): try: passwd = request["client_secret"] except KeyError: - passwd = service.superior_get("context").get_usage('client_secret') + passwd = service.upstream_get("context").get_usage('client_secret') return passwd @staticmethod @@ -103,7 +103,7 @@ def _get_user(service, **kwargs): try: user = kwargs["user"] except KeyError: - user = service.superior_get("context").get_client_id() + user = service.upstream_get("context").get_client_id() return user def _get_authentication_token(self, request, service, **kwargs): @@ -138,7 +138,7 @@ def _with_or_without_client_id(request, service): ): if "client_id" not in request: try: - request["client_id"] = service.superior_get("context").get_client_id() + request["client_id"] = service.upstream_get("context").get_client_id() except AttributeError: pass else: @@ -215,7 +215,7 @@ def modify_request(self, request, service, **kwargs): :param request: The request :param service: The service that is using this authentication method """ - _context = service.superior_get("context") + _context = service.upstream_get("context") if "client_secret" not in request: try: request["client_secret"] = kwargs["client_secret"] @@ -272,7 +272,7 @@ def find_token(request, token_type, service, **kwargs): except KeyError: # Get the latest acquired access token. _state = kwargs.get("state", kwargs.get("key")) - _arg = service.superior_get("context").cstate.get_set(_state, claim=[token_type]) + _arg = service.upstream_get("context").cstate.get_set(_state, claim=[token_type]) return _arg.get("access_token") @@ -285,7 +285,7 @@ def construct(self, request=None, service=None, http_args=None, **kwargs): the Authorization header is "Bearer ". :param request: Request class instance - :param service: Service + :param service: The service this authentication method applies to. :param http_args: HTTP header arguments :param kwargs: extra keyword arguments :return: @@ -399,7 +399,7 @@ def choose_algorithm(context, **kwargs): return algorithm @staticmethod - def get_signing_key_from_keyjar(algorithm, service_context): + def get_signing_key_from_keyjar(algorithm, keyjar): """ Pick signing key based on signing algorithm to be used @@ -408,10 +408,10 @@ def get_signing_key_from_keyjar(algorithm, service_context): instance :return: A key """ - return service_context.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) + return keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) @staticmethod - def _get_key_by_kid(kid, algorithm, service_context): + def _get_key_by_kid(kid, algorithm, keyjar): """ Pick a key that matches a given key ID and signing algorithm. @@ -422,7 +422,7 @@ def _get_key_by_kid(kid, algorithm, service_context): :return: A matching key """ # signing so using my keys - for _key in service_context.keyjar.get_issuer_keys(""): + for _key in keyjar.get_issuer_keys(""): if kid == _key.kid: ktype = alg2keytype(algorithm) if _key.kty != ktype: @@ -432,20 +432,20 @@ def _get_key_by_kid(kid, algorithm, service_context): raise MissingKey("No key with kid:%s" % kid) - def _get_signing_key(self, algorithm, context, kid=None): + def _get_signing_key(self, algorithm, keyjar, key_types, kid=None): ktype = alg2keytype(algorithm) try: if kid: - signing_key = [self._get_key_by_kid(kid, algorithm, context)] - elif ktype in context.kid["sig"]: + signing_key = [self._get_key_by_kid(kid, algorithm, keyjar)] + elif ktype in key_types: try: signing_key = [ - self._get_key_by_kid(context.kid["sig"][ktype], algorithm, context) + self._get_key_by_kid(key_types[ktype], algorithm, keyjar) ] except KeyError: - signing_key = self.get_signing_key_from_keyjar(algorithm, context) + signing_key = self.get_signing_key_from_keyjar(algorithm, keyjar) else: - signing_key = self.get_signing_key_from_keyjar(algorithm, context) + signing_key = self.get_signing_key_from_keyjar(algorithm, keyjar) except (MissingKey,) as err: LOGGER.error("%s", sanitize(err)) raise @@ -482,13 +482,16 @@ def _get_audience_and_algorithm(self, context, **kwargs): return audience, algorithm def _construct_client_assertion(self, service, **kwargs): - _context = service.superior_get("context") + _context = service.upstream_get("context") + _entity = service.upstream_get("entity") + _keyjar = service.upstream_get('attribute', 'keyjar') audience, algorithm = self._get_audience_and_algorithm(_context, **kwargs) if "kid" in kwargs: - signing_key = self._get_signing_key(algorithm, _context, kid=kwargs["kid"]) + signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"], + kid=kwargs["kid"]) else: - signing_key = self._get_signing_key(algorithm, _context) + signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"]) if not signing_key: raise UnsupportedAlgorithm(algorithm) @@ -564,8 +567,8 @@ class ClientSecretJWT(JWSAuthnMethod): def choose_algorithm(self, context="client_secret_jwt", **kwargs): return JWSAuthnMethod.choose_algorithm(context, **kwargs) - def get_signing_key_from_keyjar(self, algorithm, service_context): - return service_context.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) + def get_signing_key_from_keyjar(self, algorithm, keyjar): + return keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) class PrivateKeyJWT(JWSAuthnMethod): @@ -576,8 +579,8 @@ class PrivateKeyJWT(JWSAuthnMethod): def choose_algorithm(self, context="private_key_jwt", **kwargs): return JWSAuthnMethod.choose_algorithm(context, **kwargs) - def get_signing_key_from_keyjar(self, algorithm, service_context=None): - return service_context.keyjar.get_signing_key(alg2keytype(algorithm), "", alg=algorithm) + def get_signing_key_from_keyjar(self, algorithm, keyjar): + return keyjar.get_signing_key(alg2keytype(algorithm), "", alg=algorithm) # Map from client authentication identifiers to corresponding class diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 7a1b2a48..7f1739aa 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -5,6 +5,7 @@ from cryptojwt import KeyJar from cryptojwt.key_jar import init_key_jar +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.client_auth import client_auth_setup from idpyoidc.client.configure import Configuration from idpyoidc.client.configure import get_configuration @@ -12,6 +13,8 @@ from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.service import init_services from idpyoidc.client.service_context import ServiceContext +from idpyoidc.context import OidcContext +from idpyoidc.node import Unit logger = logging.getLogger(__name__) @@ -70,54 +73,47 @@ def redirect_uris_from_callback_uris(callback_uris): return res -class Entity(object): - +class Entity(Unit): def __init__( self, keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, services: Optional[dict] = None, + jwks_uri: Optional[str] = "", + httpc: Optional[Callable] = None, httpc_params: Optional[dict] = None, client_type: Optional[str] = "oauth2", context: Optional[OidcContext] = None, - superior_get: Optional[Callable] = None + upstream_get: Optional[Callable] = None, + key_conf: Optional[dict] = None, + entity_id: Optional[str] = '' ): - self.extra = {} - if httpc_params: - self.httpc_params = httpc_params - else: - self.httpc_params = {"verify": True} + Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, + httpc_params=httpc_params, config=config, key_conf=key_conf, + entity_id=entity_id) - config = get_configuration(config) + if context: + self._service_context = context + else: + self._service_context = ServiceContext(config=config, jwks_uri=jwks_uri, + upstream_get=self.unit_get) - if config: - _srvs = config.conf.get("services") + if services: + _srvs = services + elif config: + _srvs = config.get("services") else: _srvs = None if not _srvs: _srvs = DEFAULT_OAUTH2_SERVICES - self._service = init_services(service_definitions=_srvs, superior_get=self.entity_get) - - self._service_context = ServiceContext( - keyjar=keyjar, config=config, httpc_params=self.httpc_params, - client_type=client_type, client_get=self.client_get - ) + self._service = init_services(service_definitions=_srvs, upstream_get=self.unit_get) self.keyjar = self._service_context.get_preference('keyjar') self.setup_client_authn_methods(config) - self.superior_get = superior_get - - # Deal with backward compatibility - self.backward_compatibility(config) - - def entity_get(self, what, *arg): - _func = getattr(self, "get_{}".format(what), None) - if _func: - return _func(*arg) - return None + self.upstream_get = upstream_get def get_services(self, *arg): return self._service @@ -151,52 +147,15 @@ def get_client_id(self): else: return self._service_context.work_environment.get_preference('client_id') - def get_keyjar(self): - if self.get_service_context().keyjar: - return self.get_service_context().keyjar - else: - return self.superior_get('application', 'keyjar') - def setup_client_authn_methods(self, config): if config and "client_authn_methods" in config: self._service_context.client_authn_method = client_auth_setup( config.get("client_authn_methods") ) else: - self._service_context.client_authn_method = {} - - def backward_compatibility(self, config): - _work_environment = self._service_context.work_environment - _uris = config.get("redirect_uris") - if _uris: - _work_environment.set_preference("redirect_uris", _uris) - - _dir = config.conf.get("requests_dir") - if _dir: - _work_environment.set_preference('requests_dir', _dir) - - _pref = config.get("client_preferences", {}) - for key, val in _pref.items(): - _work_environment.set_preference(key, val) - - auth_request_args = config.conf.get("request_args", {}) - if auth_request_args: - authz_serv = self.get_service('authorization') - authz_serv.default_request_args.update(auth_request_args) - - def config_args(self): - res = {} - for id, service in self._service.items(): - res[id] = { - "preference": service.supports(), - } - res[""] = { - "preference": self._service_context.work_environment.supports, - } - return res - - def prefers(self): - return self._service_context.work_environment.prefers() - - def use(self): - return self._service_context.work_environment.get_use() + _default_methods = set( + [s.default_authn_method for s in self._service.db.values() if + s.default_authn_method]) + _methods = {m: CLIENT_AUTHN_METHOD[m] for m in _default_methods if + m in CLIENT_AUTHN_METHOD} + self._service_context.client_authn_method = client_auth_setup(_methods) diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 170822db..9c6239c0 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -1,19 +1,23 @@ -import logging from json import JSONDecodeError +import logging from typing import Callable from typing import Optional +from typing import Union + +from cryptojwt.key_jar import KeyJar +from requests import request from idpyoidc.client.entity import Entity from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.exception import OidcServiceError from idpyoidc.client.exception import ParseError -from idpyoidc.client.http import HTTPLib from idpyoidc.client.service import REQUEST_INFO from idpyoidc.client.service import SUCCESSFUL from idpyoidc.client.service import Service from idpyoidc.client.util import do_add_ons from idpyoidc.client.util import get_deserialization_method from idpyoidc.configure import Configuration +from idpyoidc.context import OidcContext from idpyoidc.exception import FormatError from idpyoidc.message import Message from idpyoidc.message.oauth2 import is_error_message @@ -34,16 +38,20 @@ class ExpiredToken(Exception): class Client(Entity): def __init__( - self, - keyjar=None, - verify_ssl=True, - config=None, - httpc=None, - services=None, - httpc_params=None, - superior_get: Optional[Callable] = None, - client_type: Optional[str] = "" - **kwargs + self, + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + services: Optional[dict] = None, + httpc: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + context: Optional[OidcContext] = None, + upstream_get: Optional[Callable] = None, + key_conf: Optional[dict] = None, + entity_id: Optional[str] = '', + verify_ssl: Optional[bool] = True, + jwks_uri: Optional[str] = "", + client_type: Optional[str] = "", + **kwargs ): """ @@ -54,25 +62,38 @@ def __init__( :py:class:`idpyoidc.client.service_context.ServiceContext` initialization :param httpc: A HTTP client to use - :param services: A list of service definitions :param httpc_params: HTTP request arguments + :param services: A list of service definitions + :param jwks_uri: A jwks_uri :return: Client instance """ if not client_type: client_type = "oauth2" + if verify_ssl in False: + # just ignore verify_ssl until it goes away + if httpc_params: + httpc_params['verify'] = False + else: + httpc_params = {'verify': False} + Entity.__init__( self, keyjar=keyjar, config=config, services=services, + jwks_uri=jwks_uri, + httpc=httpc, httpc_params=httpc_params, - superior_get=superior_get - client_type=client_type + client_type=client_type, + context=context, + upstream_get=upstream_get, + key_conf=key_conf, + entity_id=entity_id ) - self.http = httpc or HTTPLib(httpc_params) + self.httpc = httpc or request if isinstance(config, Configuration): _add_ons = config.conf.get("add_ons") @@ -82,16 +103,13 @@ def __init__( if _add_ons: do_add_ons(_add_ons, self._service) - # just ignore verify_ssl until it goes away - self.verify_ssl = self.httpc_params.get("verify", True) - def do_request( - self, - request_type: str, - response_body_type: Optional[str] = "", - request_args: Optional[dict] = None, - behaviour_args: Optional[dict] = None, - **kwargs + self, + request_type: str, + response_body_type: Optional[str] = "", + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, + **kwargs ): _srv = self._service[request_type] @@ -114,14 +132,14 @@ def set_client_id(self, client_id): self._service_context.set("client_id", client_id) def get_response( - self, - service: Service, - url: str, - method: Optional[str] = "GET", - body: Optional[dict] = None, - response_body_type: Optional[str] = "", - headers: Optional[dict] = None, - **kwargs + self, + service: Service, + url: str, + method: Optional[str] = "GET", + body: Optional[dict] = None, + response_body_type: Optional[str] = "", + headers: Optional[dict] = None, + **kwargs ): """ @@ -134,7 +152,7 @@ def get_response( :return: """ try: - resp = self.http(url, method, data=body, headers=headers) + resp = self.httpc(url, method, data=body, headers=headers) except Exception as err: logger.error("Exception on request: {}".format(err)) raise @@ -144,7 +162,7 @@ def get_response( if resp.status_code < 300: if "keyjar" not in kwargs: - kwargs["keyjar"] = service.superior_get("context").keyjar + kwargs["keyjar"] = self.get_attribute('keyjar') if not response_body_type: response_body_type = service.response_body_type @@ -157,14 +175,14 @@ def get_response( return self.parse_request_response(service, resp, response_body_type, **kwargs) def service_request( - self, - service: Service, - url: str, - method: Optional[str] = "GET", - body: Optional[dict] = None, - response_body_type: Optional[str] = "", - headers: Optional[dict] = None, - **kwargs + self, + service: Service, + url: str, + method: Optional[str] = "GET", + body: Optional[dict] = None, + response_body_type: Optional[str] = "", + headers: Optional[dict] = None, + **kwargs ) -> Message: """ The method that sends the request and handles the response returned. @@ -203,7 +221,7 @@ def service_request( def parse_request_response(self, service, reqresp, response_body_type="", state="", **kwargs): """ - Deal with a self.http response. The response are expected to + Deal with a self.httpc response. The response are expected to follow a special pattern, having the attributes: - headers (list of tuples with headers attributes and their values) @@ -297,7 +315,7 @@ def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dic except KeyError: raise ConfigurationError("Can not do dynamic provider info discovery") else: - _context = client.superior_get("context") + _context = client.get_context() try: _context.set("issuer", _context.config["srv_discovery_url"]) except KeyError: diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 1ccb61e0..51b87ce4 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -32,15 +32,15 @@ class AccessToken(Service): "token_endpoint_auth_signing_alg": get_signing_algs, } - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) def update_service_context(self, resp, key: Optional[str] = '', **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) if key: - self.superior_get("context").cstate.update(key, resp) + self.upstream_get("context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): """ @@ -52,7 +52,7 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): _state = get_state_parameter(request_args, kwargs) parameters = list(self.msg_type.c_param.keys()) - _context = self.superior_get("context") + _context = self.upstream_get("context") _args = _context.cstate.get_set(_state, claim=parameters) if "grant_type" not in _args: diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index c52e6dc1..a83fdaa9 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -154,7 +154,7 @@ def add_support(services, signing_algorithms): # Access token request should use DPoP header _service = services["accesstoken"] - _context = _service.superior_get("context") + _context = _service.upstream_get("context") _context.add_on["dpop"] = { # "key": key_by_alg(signing_algorithm), "sign_algs": signing_algorithms diff --git a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py index 9815896c..ea1253cd 100644 --- a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py +++ b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py @@ -73,7 +73,7 @@ def add_support( # Access token request should use DPoP header _service = services["userinfo"] - _context = _service.superior_get("context") + _context = _service.upstream_get("context") _context.add_on["identity_assurance"] = { "verified_claims_supported": True, "trust_frameworks_supported": trust_frameworks_supported, diff --git a/src/idpyoidc/client/oauth2/add_on/pkce.py b/src/idpyoidc/client/oauth2/add_on/pkce.py index 5c250015..f9491975 100644 --- a/src/idpyoidc/client/oauth2/add_on/pkce.py +++ b/src/idpyoidc/client/oauth2/add_on/pkce.py @@ -22,7 +22,7 @@ def add_code_challenge(request_args, service, **kwargs): :param kwargs: Extra set of keyword arguments :return: Updated set of request arguments """ - _context = service.superior_get("context") + _context = service.upstream_get("context") _kwargs = _context.add_on["pkce"] try: @@ -69,7 +69,7 @@ def add_code_verifier(request_args, service, **kwargs): _state = request_args.get("state") if _state is None: _state = kwargs.get("state") - _item = service.superior_get("context").cstate.get_set(_state, claim=['code_verifier']) + _item = service.upstream_get("context").cstate.get_set(_state, claim=['code_verifier']) request_args.update(_item) return request_args @@ -91,7 +91,7 @@ def add_support(service, code_challenge_length, code_challenge_method): """ if "authorization" in service and "accesstoken" in service: _service = service["authorization"] - _context = _service.superior_get("context") + _context = _service.upstream_get("context") _context.add_on["pkce"] = { "code_challenge_length": code_challenge_length, "code_challenge_method": code_challenge_method, diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py index d40c7d52..c8790c3d 100644 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py @@ -16,14 +16,15 @@ def push_authorization(request_args, service, **kwargs): :param kwargs: Extra keyword arguments. """ - _context = service.superior_get("context") + _context = service.upstream_get("context") method_args = _context.add_on["pushed_authorization"] # construct the message body if method_args["body_format"] == "urlencoded": _body = request_args.to_urlencoded() else: - _jwt = JWT(key_jar=_context.keyjar, iss=_context.base_url) + _jwt = JWT(key_jar=service.upstream_get('attribute','keyjar'), + iss=_context.base_url) _jws = _jwt.pack(request_args.to_dict()) _msg = Message(request=_jws) @@ -66,7 +67,7 @@ def add_support( http_client = requests _service = services["authorization"] - _service.superior_get("context").add_on["pushed_authorization"] = { + _service.upstream_get("context").add_on["pushed_authorization"] = { "body_format": body_format, "signing_algorithm": signing_algorithm, "http_client": http_client, diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index e28b1fff..e3f1b0ac 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -45,20 +45,20 @@ class Authorization(Service): } } - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter]) self.post_construct.append(self.store_auth_request) def update_service_context(self, resp, key="", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.superior_get("context").cstate.update(key, resp) + self.upstream_get("context").cstate.update(key, resp) def store_auth_request(self, request_args=None, **kwargs): """Store the authorization request in the state DB.""" _key = get_state_parameter(request_args, kwargs) - self.superior_get("context").cstate.update(_key, request_args) + self.upstream_get("context").cstate.update(_key, request_args) return request_args def gather_request_args(self, **kwargs): @@ -66,7 +66,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: - ar_args["redirect_uri"] = self.superior_get("context").get_usage( + ar_args["redirect_uri"] = self.upstream_get("context").get_usage( "redirect_uris")[0] except (KeyError, AttributeError): raise MissingParameter("redirect_uri") @@ -90,7 +90,7 @@ def post_parse_response(self, response, **kwargs): pass else: if _key: - item = self.superior_get("context").cstate.get_set( + item = self.upstream_get("context").cstate.get_set( _key, message=oauth2.AuthorizationRequest) try: response["scope"] = item["scope"] diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py index 9f69f4b6..af65573a 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py @@ -18,10 +18,10 @@ class CCAccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) def update_service_context(self, resp, key: Optional[str] = "cc", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.superior_get("context").cstate.update(key, resp) + self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py index 111cc684..69ac5ff5 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py @@ -16,15 +16,15 @@ class CCRefreshAccessToken(Service): default_authn_method = "bearer_header" http_method = "POST" - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.cc_pre_construct) self.post_construct.append(self.cc_post_construct) def cc_pre_construct(self, request_args=None, **kwargs): _state_id = kwargs.get("state", "cc") parameters = ["refresh_token"] - _current = self.superior_get("context").cstate + _current = self.upstream_get("context").cstate _args = _current.get_set(_state_id, claim=parameters) if request_args is None: @@ -47,4 +47,4 @@ def cc_post_construct(self, request_args, **kwargs): def update_service_context(self, resp, key="cc", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.superior_get("context").cstate.update(key, resp) + self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/refresh_access_token.py b/src/idpyoidc/client/oauth2/refresh_access_token.py index f3345bfc..6dbc6d5a 100644 --- a/src/idpyoidc/client/oauth2/refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/refresh_access_token.py @@ -23,21 +23,21 @@ class RefreshAccessToken(Service): default_authn_method = "bearer_header" http_method = "POST" - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.superior_get("context").cstate.update(key, resp) + self.upstream_get("context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, **kwargs): """Preconstructor of request arguments""" _state = get_state_parameter(request_args, kwargs) parameters = list(self.msg_type.c_param.keys()) - _current = self.superior_get("context").cstate + _current = self.upstream_get("context").cstate _args = _current.get_set(_state, claim=parameters) if request_args is None: diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index 49da6959..bb4ba306 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -24,8 +24,8 @@ class ServerMetadata(Service): _supports = {} - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) def get_endpoint(self): """ @@ -34,7 +34,7 @@ def get_endpoint(self): :return: Service endpoint """ try: - _iss = self.superior_get("context").issuer + _iss = self.upstream_get("context").issuer except AttributeError: _iss = self.endpoint @@ -69,7 +69,7 @@ def _verify_issuer(self, resp, issuer): # In some cases we can live with the two URLs not being # the same. But this is an excepted that has to be explicit try: - self.superior_get("context").allow["issuer_mismatch"] + self.upstream_get("context").allow["issuer_mismatch"] except KeyError: if _issuer != _pcr_issuer: raise OidcServiceError( @@ -86,7 +86,7 @@ def _set_endpoints(self, resp): # a name ending in '_endpoint' so I can look specifically # for those if key.endswith("_endpoint"): - _srv = self.superior_get("service_by_endpoint_name", key) + _srv = self.upstream_get("service_by_endpoint_name", key) if _srv: _srv.endpoint = val @@ -99,7 +99,7 @@ def _update_service_context(self, resp): :param service_context: Information collected/used by services """ - _context = self.superior_get("context") + _context = self.upstream_get("context") # Verify that the issuer value received is the same as the # url that was used as service endpoint (without the .well-known part) if "issuer" in resp: @@ -115,7 +115,7 @@ def _update_service_context(self, resp): # If I already have a Key Jar then I'll add then provider keys to # that. Otherwise a new Key Jar is minted try: - _keyjar = _context.keyjar + _keyjar = self.upstream_get('attribute', 'keyjar') except KeyError: _keyjar = KeyJar() @@ -126,7 +126,5 @@ def _update_service_context(self, resp): elif "jwks" in resp: _keyjar.load_keys(_pcr_issuer, jwks=resp["jwks"]) - _context.keyjar = _keyjar - def update_service_context(self, resp, **kwargs): return self._update_service_context(resp) diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index a87d70d0..e16ce052 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -80,8 +80,8 @@ def pre_construct_pick_redirect_uri( request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, **kwargs ): - request_args["redirect_uri"] = pick_redirect_uri(service.superior_get("context"), - entity=service.superior_get("entity"), + request_args["redirect_uri"] = pick_redirect_uri(service.upstream_get("context"), + entity=service.upstream_get("entity"), request_args=request_args) return request_args, {} diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 759ede90..e70df309 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -2,6 +2,9 @@ import logging from typing import Callable from typing import Optional +from typing import Union + +from cryptojwt.key_jar import KeyJar from idpyoidc.client import oauth2 from idpyoidc.client.client_auth import BearerHeader @@ -77,28 +80,34 @@ class RP(oauth2.Client): def __init__( self, - keyjar=None, - verify_ssl=True, - config=None, - httpc=None, - services=None, - httpc_params=None, - superior_get: Optional[Callable] = None, + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + services: Optional[dict] = None, + httpc: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + context: Optional[OidcContext] = None, + upstream_get: Optional[Callable] = None, + key_conf: Optional[dict] = None, + entity_id: Optional[str] = '', + verify_ssl: Optional[bool] = True, + jwks_uri: Optional[str] = "", **kwargs ): - self.superior_get = superior_get + self.upstream_get = upstream_get _srvs = services or DEFAULT_OIDC_SERVICES oauth2.Client.__init__( self, keyjar=keyjar, - verify_ssl=verify_ssl, config=config, - httpc=httpc, services=_srvs, + httpc=httpc, httpc_params=httpc_params, - client_type="oidc", - superior_get=superior_get, + upstream_get=upstream_get, + key_conf=key_conf, + entity_id=entity_id, + verify_ssl=verify_ssl, + jwks_uri=jwks_uri, **kwargs ) @@ -124,20 +133,20 @@ def fetch_distributed_claims(self, userinfo, callback=None): if "access_token" in spec: cauth = BearerHeader() httpc_params = cauth.construct( - service=self.superior_get("service", "userinfo"), + service=self.get_service("userinfo"), access_token=spec["access_token"], ) - _resp = self.http.send(spec["endpoint"], "GET", **httpc_params) + _resp = self.httpc.send(spec["endpoint"], "GET", **httpc_params) else: if callback: token = callback(spec["endpoint"]) cauth = BearerHeader() httpc_params = cauth.construct( - service=self.superior_get("service", "userinfo"), access_token=token + service=self.get_service("userinfo"), access_token=token ) - _resp = self.http.send(spec["endpoint"], "GET", **httpc_params) + _resp = self.httpc.send(spec["endpoint"], "GET", **httpc_params) else: - _resp = self.http.send(spec["endpoint"], "GET") + _resp = self.httpc.send(spec["endpoint"], "GET") if _resp.status_code == 200: _uinfo = json.loads(_resp.text) diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index b6980e90..87324567 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -27,26 +27,25 @@ class AccessToken(access_token.AccessToken): "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } - def __init__(self, client_get, conf: Optional[dict] = None): - access_token.AccessToken.__init__(self, client_get, conf=conf) - def __init__(self, superior_get, conf: Optional[dict] = None): - access_token.AccessToken.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf: Optional[dict] = None): + access_token.AccessToken.__init__(self, upstream_get, conf=conf) def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() :return: dictionary with arguments to the verify call """ - _context = self.superior_get("context") - _entity = self.superior_get("entity") + _context = self.upstream_get("context") + _entity = self.upstream_get("entity") kwargs = { "client_id": _entity.get_client_id(), "iss": _context.issuer, - "keyjar": _context.keyjar, + "keyjar": self.upstream_get('attribute', 'keyjar'), "verify": True, "skew": _context.clock_skew, } @@ -72,7 +71,7 @@ def gather_verify_arguments( return kwargs def update_service_context(self, resp, key: Optional[str] ="", **kwargs): - _cstate = self.superior_get("context").cstate + _cstate = self.upstream_get("context").cstate try: _idt = resp[verified_claim_name("id_token")] except KeyError: @@ -92,7 +91,5 @@ def update_service_context(self, resp, key: Optional[str] ="", **kwargs): _cstate.update(key, resp) def get_authn_method(self): - try: - return self.superior_get("service_context").behaviour["token_endpoint_auth_method"] - except KeyError: - return self.default_authn_method + return self.upstream_get("context").get_preference("token_endpoint_auth_method", + self.default_authn_method) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 80a4a54d..16554ace 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -55,9 +55,11 @@ class Authorization(authorization.Authorization): } } - def __init__(self, superior_get, conf=None): - authorization.Authorization.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None, request_args: Optional[dict] = None): + authorization.Authorization.__init__(self, upstream_get, conf=conf) self.default_request_args = {"scope": ["openid"]} + if request_args: + self.default_request_args.update(request_args) self.pre_construct = [ self.set_state, pre_construct_pick_redirect_uri, @@ -68,7 +70,7 @@ def __init__(self, superior_get, conf=None): self.default_request_args['scope'] = ['openid'] def set_state(self, request_args, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") try: _state = kwargs["state"] except KeyError: @@ -82,14 +84,14 @@ def set_state(self, request_args, **kwargs): return request_args, {} def update_service_context(self, resp, key="", **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) _context.cstate.update(key, resp) def get_request_from_response(self, response): - _context = self.superior_get("service_context") + _context = self.upstream_get("service_context") return _context.cstate.get_set(response["state"], message=oauth2.AuthorizationRequest) def post_parse_response(self, response, **kwargs): @@ -110,7 +112,7 @@ def post_parse_response(self, response, **kwargs): return response def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") if request_args is None: request_args = {} @@ -179,7 +181,7 @@ def get_request_object_signing_alg(self, **kwargs): break if not alg: - _context = self.superior_get("context") + _context = self.upstream_get("context") try: alg = _context.work_environment.get_usage("request_object_signing_alg") except KeyError: # Use default @@ -193,7 +195,7 @@ def store_request_on_file(self, req, **kwargs): :param kwargs: Extra keyword arguments :return: The URL the OP should use to access the file """ - _context = self.superior_get("context") + _context = self.upstream_get("context") _webname = _context.get_usage("request_uris") if _webname is None: filename, _webname = construct_request_uri(**kwargs) @@ -214,9 +216,9 @@ def construct_request_parameter( alg = self.get_request_object_signing_alg(**kwargs) kwargs["request_object_signing_alg"] = alg - _context = self.superior_get("context") + _context = self.upstream_get("context") if "keys" not in kwargs and alg and alg != "none": - kwargs["keys"] = _context.keyjar + kwargs["keys"] = self.upstream_get('attribute', 'keyjar') if alg == "none": kwargs["keys"] = [] @@ -266,7 +268,7 @@ def oidc_post_construct(self, req, **kwargs): :param kwargs: Extra keyword arguments :return: A possibly modified request. """ - _context = self.superior_get("context") + _context = self.upstream_get("context") if "openid" in req["scope"]: _response_type = req["response_type"][0] if "id_token" in _response_type or "code" in _response_type: @@ -316,10 +318,10 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.superior_get("context") + _context = self.upstream_get("context") kwargs = { "iss": _context.issuer, - "keyjar": _context.keyjar, + "keyjar": self.upstream_get('attribute', 'keyjar'), "verify": True, "skew": _context.clock_skew, } diff --git a/src/idpyoidc/client/oidc/backchannel_authentication.py b/src/idpyoidc/client/oidc/backchannel_authentication.py index 0811e322..f2824e25 100644 --- a/src/idpyoidc/client/oidc/backchannel_authentication.py +++ b/src/idpyoidc/client/oidc/backchannel_authentication.py @@ -17,8 +17,8 @@ class BackChannelAuthentication(Service): service_name = "backchannel_authentication" response_body_type = "json" - def __init__(self, superior_get, conf=None, **kwargs): - Service.__init__(self, superior_get=superior_get, conf=conf, **kwargs) + def __init__(self, upstream_get, conf=None, **kwargs): + Service.__init__(self, upstream_get=upstream_get, conf=conf, **kwargs) self.default_request_args = {"scope": ["openid"]} self.pre_construct = [] self.post_construct = [] @@ -37,8 +37,8 @@ class ClientNotification(Service): response_body_type = "" http_method = "POST" - def __init__(self, superior_get, conf=None, **kwargs): - Service.__init__(self, superior_get=superior_get, conf=conf, **kwargs) + def __init__(self, upstream_get, conf=None, **kwargs): + Service.__init__(self, upstream_get=upstream_get, conf=conf, **kwargs) self.pre_construct = [] self.post_construct = [] diff --git a/src/idpyoidc/client/oidc/check_id.py b/src/idpyoidc/client/oidc/check_id.py index 712972f5..38e5897f 100644 --- a/src/idpyoidc/client/oidc/check_id.py +++ b/src/idpyoidc/client/oidc/check_id.py @@ -19,12 +19,12 @@ class CheckID(Service): synchronous = True service_name = "check_id" - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args: Optional[dict]=None, **kwargs): - _args = self.superior_get("context").cstate.get_set( + _args = self.upstream_get("context").cstate.get_set( kwargs["state"], claim=["id_token"] ) diff --git a/src/idpyoidc/client/oidc/check_session.py b/src/idpyoidc/client/oidc/check_session.py index 422142fc..373f5242 100644 --- a/src/idpyoidc/client/oidc/check_session.py +++ b/src/idpyoidc/client/oidc/check_session.py @@ -18,12 +18,12 @@ class CheckSession(Service): synchronous = True service_name = "check_session" - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args=None, **kwargs): - _args = self.superior_get("context").cstate.get_set(kwargs["state"], + _args = self.upstream_get("context").cstate.get_set(kwargs["state"], claim=["id_token"]) if request_args: request_args.update(_args) diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 47f967a7..5820f89b 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -36,8 +36,8 @@ class EndSession(Service): "post_logout_redirect_uris": "session_logout" } - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [ self.get_id_token_hint, self.add_post_logout_redirect_uri, @@ -53,9 +53,8 @@ def get_id_token_hint(self, request_args=None, **kwargs): :return: """ - _args = self.client_get('service_context').cstate.get_set(kwargs["state"], + _args = self.upstream_get("context").cstate.get_set(kwargs["state"], claim=['id_token']) - try: request_args["id_token_hint"] = _args["id_token"] except KeyError: @@ -65,12 +64,11 @@ def get_id_token_hint(self, request_args=None, **kwargs): def add_post_logout_redirect_uri(self, request_args=None, **kwargs): if "post_logout_redirect_uri" not in request_args: - _uri = self.metadata.get("post_logout_redirect_uris", '') - if _uri: - if isinstance(_uri, str): - request_args["post_logout_redirect_uri"] = _uri - else: # assume list - request_args["post_logout_redirect_uri"] = _uri[0] + _uri = self.upstream_get("context").get_usage("post_logout_redirect_uris") + if isinstance(_uri, str): + request_args["post_logout_redirect_uri"] = _uri + else: # assume list + request_args["post_logout_redirect_uri"] = _uri[0] return request_args, {} @@ -79,6 +77,6 @@ def add_state(self, request_args=None, **kwargs): request_args["state"] = rndstr(32) # As a side effect bind logout state to session state - self.client_get("service_context").cstate.bind_key(request_args["state"], kwargs["state"]) + self.upstream_get("context").cstate.bind_key(request_args["state"], kwargs["state"]) return request_args, {} diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 2343d235..0acae2c7 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -25,7 +25,7 @@ def add_redirect_uris(request_args, service=None, **kwargs): :param kwargs: Possible extra keyword arguments :return: A possibly augmented set of request arguments. """ - _work_environment = service.superior_get("context").work_environment + _work_environment = service.upstream_get("context").work_environment if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. @@ -49,15 +49,16 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): _supports = {} - def __init__(self, superior_get, conf=None): - server_metadata.ServerMetadata.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + server_metadata.ServerMetadata.__init__(self, upstream_get, conf=conf) def update_service_context(self, resp, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") self._update_service_context(resp) _context.map_supported_to_preferred(resp) if "pre_load_keys" in self.conf and self.conf["pre_load_keys"]: - _jwks = _context.keyjar.export_jwks_as_json(issuer=resp["issuer"]) + _jwks = self.upstream_get('attribute', 'keyjar').export_jwks_as_json( + issuer=resp["issuer"]) logger.info("Preloaded keys for {}: {}".format(resp["issuer"], _jwks)) def match_preferences(self, pcr=None, issuer=None): @@ -73,7 +74,7 @@ def match_preferences(self, pcr=None, issuer=None): :param pcr: Provider configuration response if available :param issuer: The issuer identifier """ - _context = self.superior_get("context") + _context = self.upstream_get("context") if not pcr: pcr = _context.provider_info diff --git a/src/idpyoidc/client/oidc/read_registration.py b/src/idpyoidc/client/oidc/read_registration.py index a105fed5..cf0a02a9 100644 --- a/src/idpyoidc/client/oidc/read_registration.py +++ b/src/idpyoidc/client/oidc/read_registration.py @@ -19,7 +19,7 @@ class RegistrationRead(Service): def get_endpoint(self): try: - return self.superior_get("context").registration_response[ + return self.upstream_get("context").registration_response[ "registration_client_uri" ] except KeyError: @@ -40,7 +40,7 @@ def get_authn_header(self, request, authn_method, **kwargs): if authn_method == "client_secret_basic": LOGGER.debug("Client authn method: %s", authn_method) headers["Authorization"] = "Bearer {}".format( - self.superior_get("context").registration_response[ + self.upstream_get("context").registration_response[ "registration_access_token" ] ) diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index 0b209ff2..8ee78d98 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -8,7 +8,7 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): error_msg = oidc.ResponseMessage def get_authn_method(self): - _work_environment = self.superior_get("context").work_environment + _work_environment = self.upstream_get("context").work_environment try: return _work_environment.get_usage("token_endpoint_auth_method") except KeyError: diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index bcce75d5..1e29935c 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -23,13 +23,13 @@ class Registration(Service): callback_path = {} - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [self.add_client_preference] self.post_construct = [self.oidc_post_construct] def add_client_preference(self, request_args=None, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") _use = _context.map_preferred_to_registered() for prop, spec in self.msg_type.c_param.items(): if prop in request_args: @@ -64,7 +64,7 @@ def update_service_context(self, resp, key="", **kwargs): # if "token_endpoint_auth_method" not in resp: # resp["token_endpoint_auth_method"] = "client_secret_basic" - _context = self.superior_get("context") + _context = self.upstream_get("context") _context.map_preferred_to_registered(resp) _keyjar = _context.keyjar diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index a6dbe231..2dd54b1c 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -49,8 +49,8 @@ class UserInfo(Service): "encrypt_userinfo_supported": None } - def __init__(self, superior_get, conf=None): - Service.__init__(self, superior_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [self.oidc_pre_construct, carry_state] def oidc_pre_construct(self, request_args=None, **kwargs): @@ -60,7 +60,7 @@ def oidc_pre_construct(self, request_args=None, **kwargs): if "access_token" in request_args: pass else: - request_args = self.superior_get("context").cstate.get_set( + request_args = self.upstream_get("context").cstate.get_set( kwargs["state"], claim=["access_token"] ) @@ -68,7 +68,7 @@ def oidc_pre_construct(self, request_args=None, **kwargs): return request_args, {} def post_parse_response(self, response, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") _current = _context.cstate _args = _current.get_set(kwargs["state"], claim=[verified_claim_name("id_token")]) @@ -89,7 +89,8 @@ def post_parse_response(self, response, **kwargs): if "JWT" in spec: try: aggregated_claims = Message().from_jwt( - spec["JWT"].encode("utf-8"), keyjar=_context.keyjar + spec["JWT"].encode("utf-8"), + keyjar=self.upstream_get('attribute', 'keyjar') ) except MissingSigningKey as err: logger.warning( @@ -111,18 +112,19 @@ def post_parse_response(self, response, **kwargs): return response def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() :return: dictionary with arguments to the verify call """ - _context = self.superior_get("context") + _context = self.upstream_get("context") kwargs = { "client_id": _context.get_client_id(), "iss": _context.issuer, - "keyjar": _context.keyjar, + "keyjar": self.upstream_get('attribute', 'keyjar'), "verify": True, "skew": _context.clock_skew, } diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index 097f6f9f..5240eeb4 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -49,11 +49,12 @@ def request_object_encryption(msg, service_context, **kwargs): if "target" not in kwargs: raise MissingRequiredAttribute("No target specified") + _keyjar = service_context.upstream_get('attribute', 'keyjar') if _kid: - _keys = service_context.keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"], kid=_kid) + _keys = _keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"], kid=_kid) _jwe["kid"] = _kid else: - _keys = service_context.keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"]) + _keys = _keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"]) return _jwe.encrypt(_keys) diff --git a/src/idpyoidc/client/oidc/webfinger.py b/src/idpyoidc/client/oidc/webfinger.py index fe4782b2..84235c30 100644 --- a/src/idpyoidc/client/oidc/webfinger.py +++ b/src/idpyoidc/client/oidc/webfinger.py @@ -35,8 +35,8 @@ class WebFinger(Service): http_method = "GET" response_body_type = "json" - def __init__(self, superior_get, conf=None, rel="", **kwargs): - Service.__init__(self, superior_get, conf=conf, **kwargs) + def __init__(self, upstream_get, conf=None, rel="", **kwargs): + Service.__init__(self, upstream_get, conf=conf, **kwargs) self.rel = rel or OIC_ISSUER @@ -55,7 +55,7 @@ def update_service_context(self, resp, key="", **kwargs): if _href.startswith("http://") and not _http_allowed: raise ValueError("http link not allowed ({})".format(_href)) - self.superior_get("context").issuer = link["href"] + self.upstream_get("context").issuer = link["href"] break return resp @@ -150,7 +150,7 @@ def get_request_parameters(self, request_args=None, **kwargs): _resource = kwargs["resource"] except KeyError: try: - _resource = self.superior_get("context").config["resource"] + _resource = self.upstream_get("context").config["resource"] except KeyError: raise MissingRequiredAttribute("resource") diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 57d22906..b444e107 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -47,7 +47,7 @@ def __init__( verify_ssl=True, client_cls=None, state_db=None, - http_lib=None, + httpc=None, httpc_params=None, config=None, **kwargs, @@ -106,7 +106,7 @@ def __init__( # keep track on which RP instance that serves which OP self.issuer2rp = {} self.hash2issuer = {} - self.httplib = http_lib + self.httpc = httpc if not httpc_params: self.httpc_params = {"verify": verify_ssl} @@ -126,7 +126,7 @@ def state2issuer(self, state): :return: An Issuer ID """ for _rp in self.issuer2rp.values(): - _iss = _rp.superior_get("context").cstate.get_set( + _iss = _rp.upstream_get("context").cstate.get_set( state, claim=['iss']).get('iss') if _iss: return _iss @@ -154,7 +154,7 @@ def get_session_information(self, key, client=None): if not client: client = self.get_client_from_session_key(key) - return client.superior_get("context").cstate.get(key) + return client.upstream_get("context").cstate.get(key) def init_client(self, issuer): """ @@ -188,7 +188,7 @@ def init_client(self, issuer): client = self.client_cls( services=_services, config=_cnf, - httplib=self.httplib, + httpc=self.httpc, httpc_params=self.httpc_params, ) except Exception as err: @@ -197,12 +197,13 @@ def init_client(self, issuer): logger.error(message) raise - _context = client.superior_get("context") + _context = client.upstream_get("context") if _context.iss_hash: self.hash2issuer[_context.iss_hash] = issuer # If non persistent - _context.keyjar.load(self.keyjar.dump()) - # If persistent nothing has to be copied + _keyjar = client.get_attribute('keyjar') + _keyjar.load(self.keyjar.dump()) + # If persistent nothings has to be copied _context.base_url = self.base_url _context.jwks_uri = self.jwks_uri @@ -232,7 +233,7 @@ def do_provider_info( else: raise ValueError("Missing state/session key") - _context = client.superior_get("context") + _context = client.get_context() if not _context.get("provider_info"): dynamic_provider_info_discovery(client, behaviour_args=behaviour_args) return _context.get("provider_info")["issuer"] @@ -243,7 +244,7 @@ def do_provider_info( # a name ending in '_endpoint' so I can look specifically # for those if key.endswith("_endpoint"): - for _srv in client.superior_get("services").values(): + for _srv in client.get_services().values(): # Every service has an endpoint_name assigned # when initiated. This name *MUST* match the # endpoint names used in the provider info @@ -251,7 +252,7 @@ def do_provider_info( _srv.endpoint = val if "keys" in _pi: - _kj = _context.keyjar + _kj = client.get_attribute('keyjar') for typ, _spec in _pi["keys"].items(): if typ == "url": for _iss, _url in _spec.items(): @@ -299,7 +300,7 @@ def do_client_registration( else: raise ValueError("Missing state/session key") - _context = client.superior_get("context") + _context = client.get_context() _iss = _context.get("issuer") self.hash2issuer[iss_id] = _iss @@ -421,8 +422,8 @@ def init_authorization( else: raise ValueError("Missing state/session key") - _context = client.superior_get("context") - _entity = client.superior_get("entity") + _context = client.upstream_get("context") + _entity = client.upstream_get("entity") _nonce = rndstr(24) _response_type = self._get_response_type(_context, req_args) request_args = { @@ -518,7 +519,7 @@ def get_client_authn_method(client, endpoint): :return: The client authentication method """ if endpoint == "token_endpoint": - am = client.superior_get("context").get_usage("token_endpoint_auth_method") + am = client.upstream_get("context").get_usage("token_endpoint_auth_method") if not am: return "" else: @@ -542,7 +543,7 @@ def get_tokens(self, state, client: Optional[Client] = None): if client is None: client = self.get_client_from_session_key(state) - _context = client.superior_get("context") + _context = client.upstream_get("context") _claims = _context.cstate.get_set(state, claim=['code', 'redirect_uri']) req_args = { @@ -628,7 +629,7 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): client = self.get_client_from_session_key(state) if not access_token: - _arg = client.superior_get("context").cstate.get_set(state, claim=["access_token"]) + _arg = client.upstream_get("context").cstate.get_set(state, claim=["access_token"]) access_token = _arg["access_token"] request_args = {"access_token": access_token} @@ -684,7 +685,7 @@ def finalize_auth( if is_error_message(authorization_response): return authorization_response - _context = client.superior_get("context") + _context = client.get_context() try: _iss = _context.cstate.get_set( authorization_response["state"], claim=['iss']).get('iss') @@ -725,7 +726,7 @@ def get_access_and_id_token( if client is None: client = self.get_client_from_session_key(state) - _context = client.superior_get("context") + _context = client.get_context() resp_attr = authorization_response or _context.cstate.get_set(state, message=AuthorizationResponse) @@ -811,7 +812,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): _id_token = token.get("id_token") logger.debug(f"ID Token: {_id_token}") - if client.superior_get("service", "userinfo") and token["access_token"]: + if client.get_service("userinfo") and token["access_token"]: inforesp = self.get_user_info( state=authorization_response["state"], client=client, @@ -828,7 +829,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): logger.debug("UserInfo: %s", inforesp) - _context = client.superior_get("context") + _context = client.get_context() try: _sid_support = _context.get("provider_info")["backchannel_logout_session_required"] except KeyError: @@ -871,7 +872,7 @@ def has_active_authentication(self, state): client = self.get_client_from_session_key(state) # Look for an IdToken - _arg = client.superior_get("context").cstate.get_set(state, + _arg = client.upstream_get("context").cstate.get_set(state, claim=["__verified_id_token"]) if _arg: @@ -895,7 +896,7 @@ def get_valid_access_token(self, state): now = utc_time_sans_frac() client = self.get_client_from_session_key(state) - _context = client.superior_get("context") + _context = client.upstream_get("context") _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) if "access_token" in _args: access_token = _args["access_token"] @@ -937,7 +938,7 @@ def logout( client = self.get_client_from_session_key(state) try: - srv = client.superior_get("service", "end_session") + srv = client.get_service("end_session") except KeyError: raise OidcServiceError("Does not know how to logout") @@ -969,7 +970,7 @@ def close( def clear_session(self, state): client = self.get_client_from_session_key(state) - client.superior_get("context").cstate.remove_state(state) + client.upstream_get("context").cstate.remove_state(state) def backchannel_logout(client, request="", request_args=None): @@ -985,11 +986,11 @@ def backchannel_logout(client, request="", request_args=None): else: raise MissingRequiredAttribute("logout_token") - _context = client.superior_get("context") + _context = client.get_context() kwargs = { "aud": client.get_client_id(), "iss": _context.get("issuer"), - "keyjar": _context.keyjar, + "keyjar": client.get_attribute('keyjar'), "allowed_sign_alg": _context.get("registration_response").get( "id_token_signed_response_alg", "RS256" ), @@ -1029,7 +1030,7 @@ def load_registration_response(client, request_args=None): :param client: A :py:class:`idpyoidc.client.oidc.Client` instance """ - if not client.superior_get("context").get_client_id(): + if not client.upstream_get("context").get_client_id(): try: response = client.do_request("registration", request_args=request_args) except KeyError: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index ed67491e..8e3053b2 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -63,17 +63,20 @@ class Service(ImpExp): "response_cls": object, } - init_args = ["superior_get"] + init_args = ["upstream_get"] _supports = {} _callback_path = {} def __init__( - self, superior_get: Callable, conf: Optional[Union[dict, Configuration]] = None, **kwargs + self, + upstream_get: Callable, + conf: Optional[Union[dict, Configuration]] = None, + **kwargs ): ImpExp.__init__(self) - self.superior_get = superior_get + self.upstream_get = upstream_get self.default_request_args = {} if conf: @@ -115,7 +118,7 @@ def gather_request_args(self, **kwargs): """ ar_args = kwargs.copy() - _context = self.superior_get("context") + _context = self.upstream_get("context") _use = _context.collect_usage() if not _use: _use = _context.map_preferred_to_registered() @@ -210,7 +213,7 @@ def update_service_context(self, resp: Message, key: Optional[str] = '', **kwarg """ pass - def construct(self, request_args=None, **kwargs): + def construct(self, request_args: Optional[dict] = None, **kwargs): """ Instantiate the request as a message class instance with attribute values gathered in a pre_construct method or in the @@ -266,7 +269,7 @@ def init_authentication_method(self, request, authn_method, http_args=None, **kw if authn_method: LOGGER.debug("Client authn method: %s", authn_method) - _context = self.superior_get("context") + _context = self.upstream_get("context") try: _func = _context.client_authn_method[authn_method] except KeyError: # not one of the common @@ -302,7 +305,7 @@ def get_endpoint(self): if self.endpoint: return self.endpoint - return self.superior_get("context").provider_info[self.endpoint_name] + return self.upstream_get("context").provider_info[self.endpoint_name] def get_authn_header( self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs @@ -359,7 +362,7 @@ def get_headers( for meth in self.construct_extra_headers: _headers = meth( - self.superior_get("context"), + self.upstream_get("context"), headers=_headers, request=request, authn_method=authn_method, @@ -406,7 +409,7 @@ def get_request_parameters( _info = {"method": method, "request": request} _args = kwargs.copy() - _context = self.superior_get("context") + _context = self.upstream_get("context") if _context.issuer: _args["iss"] = _context.issuer @@ -482,10 +485,10 @@ def gather_verify_arguments( :return: dictionary with arguments to the verify call """ - _context = self.superior_get("context") + _context = self.upstream_get("context") kwargs = { "iss": _context.issuer, - "keyjar": _context.keyjar, + "keyjar": self.upstream_get('attribute', 'keyjar'), "verify": True, "client_id": _context.get_client_id(), } @@ -497,17 +500,17 @@ def gather_verify_arguments( return kwargs def _do_jwt(self, info): - _context = self.superior_get("context") + _context = self.upstream_get("context") args = {"allowed_sign_algs": _context.get_sign_alg(self.service_name)} enc_algs = _context.get_enc_alg_enc(self.service_name) args["allowed_enc_algs"] = enc_algs["alg"] args["allowed_enc_encs"] = enc_algs["enc"] - _jwt = JWT(key_jar=_context.keyjar, **args) + _jwt = JWT(key_jar=self.upstream_get('attribute','keyjar'), **args) _jwt.iss = _context.get_client_id() return _jwt.unpack(info) def _do_response(self, info, sformat, **kwargs): - _context = self.superior_get("context") + _context = self.upstream_get("context") try: resp = self.response_cls().deserialize(info, sformat, iss=_context.issuer, **kwargs) @@ -659,12 +662,12 @@ def callback_uris(self): return list(self._callback_path.keys()) -def init_services(service_definitions, superior_get): +def init_services(service_definitions, upstream_get): """ Initiates a set of services :param service_definitions: A dictionary containing service definitions - :param superior_get: A function that returns different things from the base entity. + :param upstream_get: A function that returns different things from the base entity. :return: A dictionary, with service name as key and the service instance as value. """ @@ -675,7 +678,7 @@ def init_services(service_definitions, superior_get): except KeyError: kwargs = {} - kwargs.update({"superior_get": superior_get}) + kwargs.update({"upstream_get": upstream_get}) if isinstance(service_configuration["class"], str): _cls = importer(service_configuration["class"]) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index ae8c9573..dc721d19 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -113,17 +113,16 @@ class ServiceContext(ImpExp): } def __init__(self, - client_get: Optional[Callable] = None, base_url: Optional[str] = "", - keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, cstate: Optional[Current] = None, + upstream_get: Optional[Callable] = None, client_type: Optional[str] = 'oauth2', **kwargs): ImpExp.__init__(self) config = get_configuration(config) self.config = config - self.client_get = client_get + self.upstream_get = upstream_get if not client_type or client_type == "oidc": self.work_environment = OIDC_Specs() diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py new file mode 100644 index 00000000..f5622247 --- /dev/null +++ b/src/idpyoidc/node.py @@ -0,0 +1,154 @@ +from typing import Callable +from typing import Optional +from typing import Union + +from cryptojwt import KeyJar +from cryptojwt.key_jar import init_key_jar +from idpyoidc.configure import Configuration +from idpyoidc.impexp import ImpExp +from idpyoidc.util import instantiate + + +class Unit(ImpExp): + name = '' + + def __init__(self, + upstream_get: Callable = None, + keyjar: Optional[KeyJar] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None + ): + ImpExp.__init__(self) + self.upstream_get = upstream_get + self.httpc = httpc + + if config is None: + config = {} + + self.entity_id = entity_id or config.get('entity_id', "") + + if keyjar or key_conf or config.get('key_conf') or config.get('jwks'): + self.keyjar = self._keyjar(keyjar, conf=config, entity_id=self.entity_id, + key_conf=key_conf) + else: + self.keyjar = None + + self.httpc_params = httpc_params or config.get("httpc_params", {}) + + if self.keyjar: + self.keyjar.httpc = self.httpc + self.keyjar.httpc_params = self.httpc_params + + def unit_get(self, what, *arg): + _func = getattr(self, "get_{}".format(what), None) + if _func: + return _func(*arg) + return None + + def get_attribute(self, attr, *args): + try: + val = getattr(self, attr) + except AttributeError: + if self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return None + else: + if val is None and self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return val + + def get_unit(self, *args): + return self + + def _keyjar(self, + keyjar: Optional[KeyJar] = None, + conf: Optional[Union[dict, Configuration]] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None): + if keyjar is None: + if key_conf: + keys_args = {k: v for k, v in key_conf.items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + elif conf: + if "keys" in conf: + keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + elif "key_conf" in conf: + keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + else: + _keyjar = KeyJar() + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") + else: + _keyjar = None + + if _keyjar and "" in _keyjar and entity_id: + # make sure I have the keys under my own name too (if I know it) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + + return _keyjar + else: + return keyjar + + +def find_topmost_unit(unit): + while hasattr(unit, 'upstream_get'): + unit = unit.upstream_get('unit') + + return unit + + +class ClientUnit(Unit): + name = '' + + def __init__(self, + upstream_get: Callable = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + keyjar: Optional[KeyJar] = None, + context: Optional[ImpExp] = None, + config: Optional[Union[Configuration, dict]] = None, + jwks_uri: Optional[str] = "", + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None + ): + Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, + httpc_params=httpc_params, config=config, entity_id=entity_id, + key_conf=key_conf) + + self._service_context = context or None + + +class Collection(Unit): + + def __init__(self, + upstream_get: Callable = None, + keyjar: Optional[KeyJar] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, + functions: Optional[dict] = None, + metadata: Optional[dict] = None + ): + + Unit.__init__(self, upstream_get, keyjar, httpc, httpc_params, config, entity_id, key_conf) + + _args = { + 'upstream_get': self.unit_get + } + + self.metadata = metadata or {} + + if functions: + for key, val in functions.items(): + _kwargs = val["kwargs"].copy() + _kwargs.update(_args) + setattr(self, key, instantiate(val["class"], **_kwargs)) diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 14887b00..f84df605 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -9,6 +9,7 @@ from idpyoidc.impexp import ImpExp from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.node import Unit from idpyoidc.server import authz from idpyoidc.server.client_authn import client_auth_setup from idpyoidc.server.configure import ASConfiguration @@ -26,15 +27,15 @@ logger = logging.getLogger(__name__) -def do_endpoints(conf, server_get): +def do_endpoints(conf, upstream_get): _endpoints = conf.get("endpoint") if _endpoints: - return build_endpoints(_endpoints, server_get=server_get, issuer=conf["issuer"]) + return build_endpoints(_endpoints, upstream_get=upstream_get, issuer=conf["issuer"]) else: return {} -class Server(ImpExp): +class Server(Unit): parameter = {"endpoint": [Endpoint], "context": EndpointContext} def __init__( @@ -43,20 +44,24 @@ def __init__( keyjar: Optional[KeyJar] = None, cwd: Optional[str] = "", cookie_handler: Optional[Any] = None, - httpc: Optional[Any] = None, - parent_get: Optional[Callable] = None + httpc: Optional[Callable] = None, + upstream_get: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None ): - ImpExp.__init__(self) + Unit.__init__(self, config=conf, keyjar=keyjar, httpc=httpc, upstream_get=upstream_get, + httpc_params=httpc_params, entity_id=entity_id, key_conf=key_conf) + + self.upstream_get = upstream_get self.conf = conf self.endpoint_context = EndpointContext( conf=conf, - server_get=self.server_get, - keyjar=keyjar, + upstream_get=self.server_get, # points to me cwd=cwd, - cookie_handler=cookie_handler, - httpc=httpc, + cookie_handler=cookie_handler ) - self.parent_get = parent_get + self.endpoint_context.authz = self.setup_authz() self.setup_authentication(self.endpoint_context) @@ -80,7 +85,7 @@ def __init__( self.setup_client_authn_methods() for endpoint_name, _ in self.endpoint.items(): - self.endpoint[endpoint_name].server_get = self.server_get + self.endpoint[endpoint_name].upstream_get = self.server_get _token_endp = self.endpoint.get("token") if _token_endp: @@ -102,6 +107,15 @@ def server_get(self, what, *arg): return _func(*arg) return None + def get_attribute(self, attribute, *args): + try: + getattr(self, attribute) + except AttributeError: + if self.upstream_get: + return self.upstream_get(attribute) + else: + return None + def get_endpoints(self, *arg): return self.endpoint @@ -111,12 +125,18 @@ def get_endpoint(self, endpoint_name, *arg): except KeyError: return None + def get_context(self, *arg): + return self.endpoint_context + def get_endpoint_context(self, *arg): return self.endpoint_context def get_server(self, *args): return self + def get_entity(self, *args): + return self + def setup_authz(self): authz_spec = self.conf.get("authz") if authz_spec: diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index 2ae4bf9c..b90e5ce3 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -14,8 +14,8 @@ class AuthzHandling(object): """Class that allow an entity to manage authorization""" - def __init__(self, server_get, grant_config=None, **kwargs): - self.server_get = server_get + def __init__(self, upstream_get, grant_config=None, **kwargs): + self.upstream_get = upstream_get self.grant_config = grant_config or {} self.kwargs = kwargs @@ -29,7 +29,7 @@ def usage_rules(self, client_id: Optional[str] = ""): return _usage_rules try: - _per_client = self.server_get("context").cdb[client_id]["token_usage_rules"] + _per_client = self.upstream_get("context").cdb[client_id]["token_usage_rules"] except KeyError: pass else: @@ -61,7 +61,7 @@ def __call__( request: Union[dict, Message], resources: Optional[list] = None, ) -> Grant: - session_info = self.server_get("context").session_manager.get_session_info( + session_info = self.upstream_get("context").session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -86,7 +86,7 @@ def __call__( if not scopes: scopes = request.get("scope", []) grant.scope = scopes - grant.claims = self.server_get("context").claims_interface.get_claims_all_usage( + grant.claims = self.upstream_get("context").claims_interface.get_claims_all_usage( session_id=session_id, scopes=scopes ) @@ -101,13 +101,13 @@ def __call__( resources: Optional[list] = None, ) -> Grant: args = self.grant_config.copy() - grant = self.server_get("context").session_manager.get_grant(session_id=session_id) + grant = self.upstream_get("context").session_manager.get_grant(session_id=session_id) for arg, val in args: setattr(grant, arg, val) return grant -def factory(msgtype, server_get, **kwargs): +def factory(msgtype, upstream_get, **kwargs): """ Factory method that can be used to easily instantiate a class instance @@ -120,6 +120,6 @@ def factory(msgtype, server_get, **kwargs): if inspect.isclass(obj) and issubclass(obj, AuthzHandling): try: if obj.__name__ == msgtype: - return obj(server_get, **kwargs) + return obj(upstream_get, **kwargs) except AttributeError: pass diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 26d23731..a7b4c1d9 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -14,6 +14,7 @@ from cryptojwt.exception import MissingKey from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes from cryptojwt.utils import as_unicode @@ -39,15 +40,14 @@ class ClientAuthnMethod(object): tag = None - def __init__(self, server_get): + def __init__(self, upstream_get): """ - :param server_get: A method that can be used to get general server information. + :param upstream_get: A method that can be used to get general server information. """ - self.server_get = server_get + self.upstream_get = upstream_get def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -74,7 +74,6 @@ def verify( :return: """ res = self._verify( - self.server_get("context"), request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -125,7 +124,6 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -147,7 +145,6 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -172,15 +169,14 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] **kwargs, ): client_info = basic_authn(authorization_token) - - if endpoint_context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: + _context = self.upstream_get('context') + if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: raise ClientAuthenticationError() @@ -205,13 +201,13 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] **kwargs, ): - if endpoint_context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: + _context = self.upstream_get('context') + if _context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: return {"client_id": request["client_id"]} else: raise ClientAuthenticationError("secrets doesn't match") @@ -229,7 +225,6 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] @@ -237,8 +232,9 @@ def _verify( **kwargs, ): token = authorization_token.split(" ", 1)[1] + _context = self.upstream_get('context') try: - client_id = get_client_id_from_token(endpoint_context, token, request) + client_id = get_client_id_from_token(_context, token, request) except ToOld: raise BearerTokenAuthenticationError("Expired token") except KeyError: @@ -259,13 +255,11 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - endpoint_context: "EndpointContext", - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token=None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): _token = request.get("access_token") if _token is None: @@ -289,14 +283,15 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] key_type: Optional[str] = None, **kwargs, ): - _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) + _context = self.upstream_get('context') + _keyjar = self.upstream_get('attribute','keyjar') + _jwt = JWT(_keyjar, msg_cls=JsonWebToken) try: ca_jwt = _jwt.unpack(request["client_assertion"]) except (Invalid, MissingKey, BadSignature) as err: @@ -307,10 +302,10 @@ def _verify( if _sign_alg and _sign_alg.startswith("HS"): if key_type == "private_key": raise AttributeError("Wrong key type") - keys = endpoint_context.keyjar.get( + keys = _keyjar.get( "sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid") ) - _secret = endpoint_context.cdb[ca_jwt["iss"]].get("client_secret") + _secret = _context.cdb[ca_jwt["iss"]].get("client_secret") if _secret and keys[0].key != as_bytes(_secret): raise AttributeError("Oct key used for signing not client_secret") else: @@ -321,7 +316,7 @@ def _verify( logger.debug("authntoken: {}".format(authtoken)) if endpoint is None or not endpoint: - if endpoint_context.issuer in ca_jwt["aud"]: + if _context.issuer in ca_jwt["aud"]: pass else: raise InvalidToken("Not for me!") @@ -335,10 +330,10 @@ def _verify( _jti = ca_jwt.get("jti") if _jti: _key = "{}:{}".format(ca_jwt["iss"], _jti) - if _key in endpoint_context.jti_db: + if _key in _context.jti_db: raise InvalidToken("Have seen this token once before") else: - endpoint_context.jti_db[_key] = utc_time_sans_frac() + _context.jti_db[_key] = utc_time_sans_frac() request[verified_claim_name("client_assertion")] = ca_jwt client_id = kwargs.get("client_id") or ca_jwt["iss"] @@ -358,14 +353,13 @@ class ClientSecretJWT(JWSAuthnMethod): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] **kwargs, ): res = super()._verify( - endpoint_context, request=request, key_type="client_secret", endpoint=endpoint, **kwargs + request=request, key_type="client_secret", endpoint=endpoint, **kwargs ) # Verify that a HS alg was used return res @@ -380,14 +374,12 @@ class PrivateKeyJWT(JWSAuthnMethod): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] **kwargs, ): res = super()._verify( - endpoint_context, request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -407,13 +399,13 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: "EndpointContext", request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] **kwargs, ): - _jwt = JWT(endpoint_context.keyjar, msg_cls=JsonWebToken) + _context = self.upstream_get('context') + _jwt = JWT(self.upstream_get('attribute', 'keyjar'), msg_cls=JsonWebToken) try: _jwt = _jwt.unpack(request["request"]) except (Invalid, MissingKey, BadSignature) as err: @@ -424,10 +416,10 @@ def _verify( _jti = _jwt.get("jti") if _jti: _key = "{}:{}".format(_jwt["iss"], _jti) - if _key in endpoint_context.jti_db: + if _key in _context.jti_db: raise InvalidToken("Have seen this token once before") else: - endpoint_context.jti_db[_key] = utc_time_sans_frac() + _context.jti_db[_key] = utc_time_sans_frac() request[verified_claim_name("client_assertion")] = _jwt client_id = kwargs.get("client_id") or _jwt["iss"] @@ -458,7 +450,6 @@ def valid_client_info(cinfo): def verify_client( - endpoint_context: "EndpointContext", request: Union[dict, Message], http_info: Optional[dict] = None, get_client_id_from_token: Optional[Callable] = None, @@ -470,7 +461,7 @@ def verify_client( :param also_known_as: :param endpoint: Endpoint instance - :param endpoint_context: EndpointContext instance + :param context: EndpointContext instance :param request: The request :param http_info: Client authentication information :param get_client_id_from_token: Function that based on a token returns a client id. @@ -486,7 +477,7 @@ def verify_client( authorization_token = None auth_info = {} - methods = endpoint_context.client_authn_method + methods = context.client_authn_method client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: @@ -499,6 +490,7 @@ def verify_client( try: logger.info(f"Verifying client authentication using {_method.tag}") auth_info = _method.verify( + keyjar=keyjar, request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -521,10 +513,10 @@ def verify_client( client_id = also_known_as[client_id] auth_info["client_id"] = client_id - if client_id not in endpoint_context.cdb: + if client_id not in context.cdb: raise UnknownClient("Unknown Client ID") - _cinfo = endpoint_context.cdb[client_id] + _cinfo = context.cdb[client_id] if not valid_client_info(_cinfo): logger.warning("Client registration has timed out or " "client secret is expired.") @@ -548,14 +540,14 @@ def verify_client( _request_type = request.__class__.__name__ _used_authn_method = _cinfo.get("auth_method") if _used_authn_method: - endpoint_context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] + context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] else: - endpoint_context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} + context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} return auth_info -def client_auth_setup(server_get, auth_set=None): +def client_auth_setup(upstream_get, auth_set=None): if auth_set is None: auth_set = CLIENT_AUTHN_METHOD else: @@ -565,7 +557,7 @@ def client_auth_setup(server_get, auth_set=None): for name, cls in auth_set.items(): if isinstance(cls, str): cls = importer(cls) - res[name] = cls(server_get) + res[name] = cls(upstream_get) return res diff --git a/src/idpyoidc/server/client_configure.py b/src/idpyoidc/server/client_configure.py index b2eb1acd..043bfa45 100644 --- a/src/idpyoidc/server/client_configure.py +++ b/src/idpyoidc/server/client_configure.py @@ -63,12 +63,12 @@ def verify(self, **kwargs): def verify_oidc_client_information( - conf: dict, server_get: Optional[Callable] = None, **kwargs + conf: dict, upstream_get: Optional[Callable] = None, **kwargs ) -> dict: res = {} for key, item in conf.items(): _rr = ClientConfiguration(**item) - _rr.verify(server_get=server_get, **kwargs) + _rr.verify(upstream_get=upstream_get, **kwargs) if _rr.extra(): logger.info(f"Extras: {_rr.extra()}") res[key] = _rr diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 0feff177..0daae06a 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -172,7 +172,7 @@ def __init__( port: Optional[int] = 0, file_attributes: Optional[List[str]] = None, dir_attributes: Optional[List[str]] = None, - superior_get: Optional[Callable] = None + upstream_get: Optional[Callable] = None ): conf = copy.deepcopy(conf) diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index f41bc91e..a35da3da 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -93,8 +93,8 @@ class Endpoint(object): _supports = {} - def __init__(self, server_get: Callable, **kwargs): - self.server_get = server_get + def __init__(self, upstream_get: Callable, **kwargs): + self.upstream_get = upstream_get self.pre_construct = [] self.post_construct = [] self.post_parse_request = [] @@ -160,7 +160,8 @@ def parse_request( LOGGER.debug("- {} -".format(self.endpoint_name)) LOGGER.info("Request: %s" % sanitize(request)) - _context = self.server_get("context") + _context = self.upstream_get("context") + _keyjar = self.upstream_get('attribute', 'keyjar') if http_info is None: http_info = {} @@ -174,7 +175,7 @@ def parse_request( req = _cls_inst.deserialize( request, "jwt", - keyjar=self.server_get("keyjar"), + keyjar=_keyjar, verify=_context.httpc_params["verify"], **kwargs ) @@ -196,14 +197,12 @@ def parse_request( else: _client_id = req.get("client_id") - keyjar = self.server_get("keyjar") - # verify that the request message is correct try: if verify_args is None: - req.verify(keyjar=keyjar, opponent_id=_client_id) + req.verify(keyjar=_keyjar, opponent_id=_client_id) else: - req.verify(keyjar=keyjar, opponent_id=_client_id, **verify_args) + req.verify(keyjar=_keyjar, opponent_id=_client_id, **verify_args) except (MissingRequiredAttribute, ValueError, MissingRequiredValue, ParameterError) as err: _error = "invalid_request" if isinstance(err, ValueError) and self.request_cls == RegistrationRequest: @@ -237,7 +236,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No kwargs["get_client_id_from_token"] = getattr(self, "get_client_id_from_token", None) authn_info = verify_client( - endpoint_context=self.server_get("context"), + context=self.upstream_get("context"), request=request, http_info=http_info, **kwargs @@ -254,19 +253,19 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No def do_post_parse_request( self, request: Message, client_id: Optional[str] = "", **kwargs ) -> Message: - _context = self.server_get("context") + _context = self.upstream_get("context") for meth in self.post_parse_request: if isinstance(request, self.error_cls): break - request = meth(request, client_id, endpoint_context=_context, **kwargs) + request = meth(request, client_id, context=_context, **kwargs) return request def do_pre_construct( self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: - _context = self.server_get("context") + _context = self.upstream_get("context") for meth in self.pre_construct: - response_args = meth(response_args, request, endpoint_context=_context, **kwargs) + response_args = meth(response_args, request, context=_context, **kwargs) return response_args @@ -276,9 +275,9 @@ def do_post_construct( request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: - _context = self.server_get("context") + _context = self.upstream_get("context") for meth in self.post_construct: - response_args = meth(response_args, request, endpoint_context=_context, **kwargs) + response_args = meth(response_args, request, context=_context, **kwargs) return response_args @@ -435,12 +434,12 @@ def do_response( def allowed_target_uris(self): res = [] - _context = self.server_get("context") + _context = self.upstream_get("context") for t in self.allowed_targets: if t == "": res.append(_context.issuer) else: - res.append(self.server_get("endpoint", t).full_path) + res.append(self.upstream_get("endpoint", t).full_path) return set(res) def supports(self): diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 2c450f32..34512093 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -51,11 +51,11 @@ def init_user_info(conf, cwd: str): return conf["class"](**kwargs) -def init_service(conf, server_get=None): +def init_service(conf, upstream_get=None): kwargs = conf.get("kwargs", {}) - if server_get: - kwargs["server_get"] = server_get + if upstream_get: + kwargs["upstream_get"] = upstream_get if isinstance(conf["class"], str): try: @@ -117,18 +117,17 @@ class EndpointContext(OidcContext): def __init__( self, conf: Union[dict, OPConfiguration], - server_get: Callable, - keyjar: Optional[KeyJar] = None, + upstream_get: Callable, cwd: Optional[str] = "", cookie_handler: Optional[Any] = None, httpc: Optional[Any] = None, - server_type: Optional[str] = '' + server_type: Optional[str] = '', entity_id: Optional[str] = "" ): _id = entity_id or conf.get("issuer", "") OidcContext.__init__(self, conf, entity_id=_id) self.conf = conf - self.server_get = server_get + self.upstream_get = upstream_get if not server_type or server_type == "oidc": self.work_environment = OIDC_Env() @@ -249,7 +248,7 @@ def __init__( self.dev_auth_db = None _interface = conf.get("claims_interface") if _interface: - self.claims_interface = init_service(_interface, self.server_get) + self.claims_interface = init_service(_interface, self.upstream_get) if isinstance(conf, OPConfiguration): self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), @@ -269,10 +268,10 @@ def set_scopes_handler(self): if _spec: _kwargs = _spec.get("kwargs", {}) _cls = importer(_spec["class"]) - self.scopes_handler = _cls(self.server_get, **_kwargs) + self.scopes_handler = _cls(self.upstream_get, **_kwargs) else: self.scopes_handler = Scopes( - self.server_get, + self.upstream_get, allowed_scopes=self.conf.get("allowed_scopes"), scopes_to_claims=self.conf.get("scopes_to_claims"), ) diff --git a/src/idpyoidc/server/login_hint.py b/src/idpyoidc/server/login_hint.py index 904a64a2..a6cea493 100644 --- a/src/idpyoidc/server/login_hint.py +++ b/src/idpyoidc/server/login_hint.py @@ -2,10 +2,10 @@ class LoginHintLookup(object): - def __init__(self, userinfo=None, server_get=None): + def __init__(self, userinfo=None, upstream_get=None): self.userinfo = userinfo self.default_country_code = "46" - self.server_get = server_get + self.upstream_get = upstream_get def __call__(self, arg): if arg.startswith("tel:"): @@ -25,9 +25,9 @@ class LoginHint2Acrs(object): OIDC Login hint support """ - def __init__(self, scheme_map, server_get=None): + def __init__(self, scheme_map, upstream_get=None): self.scheme_map = scheme_map - self.server_get = server_get + self.upstream_get = upstream_get def __call__(self, hint): p = urlparse(hint) diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 84ef7d84..4bb431c8 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -84,14 +84,14 @@ def verify_header(self, dpop_header) -> Optional["DPoPProof"]: return None -def post_parse_request(request, client_id, endpoint_context, **kwargs): +def post_parse_request(request, client_id, context, **kwargs): """ Expect http_info attribute in kwargs. http_info should be a dictionary containing HTTP information. :param request: :param client_id: - :param endpoint_context: + :param context: :param kwargs: :return: """ @@ -119,14 +119,14 @@ def post_parse_request(request, client_id, endpoint_context, **kwargs): return request -def token_args(endpoint_context, client_id, token_args: Optional[dict] = None): - dpop_jkt = endpoint_context.cdb[client_id]["dpop_jkt"] +def token_args(context, client_id, token_args: Optional[dict] = None): + dpop_jkt = context.cdb[client_id]["dpop_jkt"] _jkt = list(dpop_jkt.keys())[0] - if "dpop_jkt" in endpoint_context.cdb[client_id]: + if "dpop_jkt" in context.cdb[client_id]: if token_args is None: token_args = {"cnf": {"jkt": _jkt}} else: - token_args.update({"cnf": {"jkt": endpoint_context.cdb[client_id]["dpop_jkt"]}}) + token_args.update({"cnf": {"jkt": context.cdb[client_id]["dpop_jkt"]}}) return token_args @@ -137,17 +137,17 @@ def add_support(endpoint, **kwargs): _token_endp.post_parse_request.append(post_parse_request) # Endpoint Context stuff - # _endp.endpoint_context.token_args_methods.append(token_args) + # _endp.context.token_args_methods.append(token_args) _algs_supported = kwargs.get("dpop_signing_alg_values_supported") if not _algs_supported: _algs_supported = ["RS256"] - _token_endp.server_get("context").provider_info[ + _token_endp.upstream_get("context").provider_info[ "dpop_signing_alg_values_supported" ] = _algs_supported - _endpoint_context = _token_endp.server_get("context") - _endpoint_context.dpop_enabled = True + _context = _token_endp.upstream_get("context") + _context.dpop_enabled = True # DPoP-bound access token in the "Authorization" header and the DPoP proof in the "DPoP" header @@ -163,7 +163,7 @@ def is_usable(self, request=None, authorization_info=None, http_headers=None): def verify(self, authorization_info, **kwargs): client_info = basic_authn(authorization_info) - _context = self.server_get("context") + _context = self.upstream_get("context") if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: diff --git a/src/idpyoidc/server/oauth2/add_on/extra_args.py b/src/idpyoidc/server/oauth2/add_on/extra_args.py index dba819b4..335362ae 100644 --- a/src/idpyoidc/server/oauth2/add_on/extra_args.py +++ b/src/idpyoidc/server/oauth2/add_on/extra_args.py @@ -5,18 +5,18 @@ from idpyoidc.message.oidc import OpenIDSchema -def pre_construct(response_args, request, endpoint_context, **kwargs): +def pre_construct(response_args, request, context, **kwargs): """ Add extra arguments to the request. :param response_args: :param request: - :param endpoint_context: + :param context: :param kwargs: :return: """ - _extra = endpoint_context.add_on.get("extra_args") + _extra = context.add_on.get("extra_args") if _extra: if isinstance(response_args, AuthorizationResponse): _args = _extra.get("authorization", {}) @@ -32,7 +32,7 @@ def pre_construct(response_args, request, endpoint_context, **kwargs): _args = {} for arg, _param in _args.items(): - _val = getattr(endpoint_context, _param) + _val = getattr(context, _param) if _val: response_args[arg] = _val @@ -47,5 +47,5 @@ def add_support(endpoint, **kwargs): _endp.pre_construct.append(pre_construct) if _added is False: - _endp.server_get("context").add_on["extra_args"] = kwargs + _endp.upstream_get("context").add_on["extra_args"] = kwargs _added = True diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 3ae7991d..60d26c0f 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -93,7 +93,7 @@ def max_age(request): def verify_uri( - endpoint_context: EndpointContext, + context: EndpointContext, request: Union[dict, Message], uri_type: str, client_id: Optional[str] = None, @@ -103,7 +103,7 @@ def verify_uri( MUST NOT contain a fragment MAY contain query component - :param endpoint_context: An EndpointContext instance + :param context: An EndpointContext instance :param request: The authorization request :param uri_type: redirect_uri or post_logout_redirect_uri :return: An error response if the redirect URI is faulty otherwise @@ -128,7 +128,7 @@ def verify_uri( (_base, _query) = split_uri(_redirect_uri) # Get the clients registered redirect uris - client_info = endpoint_context.cdb.get(_cid) + client_info = context.cdb.get(_cid) if client_info is None: raise KeyError("No such client") @@ -194,10 +194,10 @@ def join_query(base, query): return base -def get_uri(endpoint_context, request, uri_type): +def get_uri(context, request, uri_type): """verify that the redirect URI is reasonable. - :param endpoint_context: An EndpointContext instance + :param context: An EndpointContext instance :param request: The Authorization request :param uri_type: 'redirect_uri' or 'post_logout_redirect_uri' :return: redirect_uri @@ -205,13 +205,13 @@ def get_uri(endpoint_context, request, uri_type): uri = "" if uri_type in request: - verify_uri(endpoint_context, request, uri_type) + verify_uri(context, request, uri_type) uri = request[uri_type] else: uris = f"{uri_type}s" client_id = str(request["client_id"]) - if client_id in endpoint_context.cdb: - _specs = endpoint_context.cdb[client_id].get(uris) + if client_id in context.cdb: + _specs = context.cdb[client_id].get(uris) if not _specs: raise ParameterError(f"Missing '{uri_type}' and none registered") @@ -267,12 +267,12 @@ def authn_args_gather( return authn_args -def check_unknown_scopes_policy(request_info, client_id, endpoint_context): - if not endpoint_context.conf["capabilities"].get("deny_unknown_scopes"): +def check_unknown_scopes_policy(request_info, client_id, context): + if not context.conf["capabilities"].get("deny_unknown_scopes"): return scope = request_info["scope"] - filtered_scopes = set(endpoint_context.scopes_handler.filter_scopes(scope, client_id=client_id)) + filtered_scopes = set(context.scopes_handler.filter_scopes(scope, client_id=client_id)) scopes = set(scope) # this prevents that authz would be released for unavailable scopes if scopes != filtered_scopes: @@ -354,15 +354,13 @@ class Authorization(Endpoint): "client_authn_method": ["request_param", "public"], } - def __init__(self, server_get, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) - - self.resource_indicators_config = kwargs.get("resource_indicators", None) + def __init__(self, upstream_get, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) - def filter_request(self, endpoint_context, req): + def filter_request(self, context, req): return req def extra_response_args(self, aresp): @@ -389,7 +387,7 @@ def mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): usage_rules = grant.usage_rules.get(token_class, {}) token = grant.mint_token( session_id=session_id, - endpoint_context=self.server_get("context"), + context=self.upstream_get("context"), token_class=token_class, based_on=based_on, usage_rules=usage_rules, @@ -402,32 +400,32 @@ def mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): if _exp_in: token.expires_at = utc_time_sans_frac() + _exp_in - _mngr = self.server_get("context").session_manager + _mngr = self.upstream_get("context").session_manager _mngr.set(_mngr.unpack_session_key(session_id), grant) return token - def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): + def _do_request_uri(self, request, client_id, context, **kwargs): _request_uri = request.get("request_uri") if _request_uri: # Do I do pushed authorization requests ? - _endp = self.server_get("endpoint", "pushed_authorization") + _endp = self.upstream_get("endpoint", "pushed_authorization") if _endp: # Is it a UUID urn if _request_uri.startswith("urn:uuid:"): - _req = endpoint_context.par_db.get(_request_uri) + _req = context.par_db.get(_request_uri) if _req: # One time usage - del endpoint_context.par_db[_request_uri] + del context.par_db[_request_uri] return _req else: raise ValueError("Got a request_uri I can not resolve") # Do I support request_uri ? - if endpoint_context.provider_info.get("request_uri_parameter_supported", True) is False: + if context.provider_info.get("request_uri_parameter_supported", True) is False: raise ServiceError("Someone is using request_uri which I'm not supporting") - _registered = endpoint_context.cdb[client_id].get("request_uris") + _registered = context.cdb[client_id].get("request_uris") # Not registered should be handled else where if _registered: # Before matching remove a possible fragment @@ -437,26 +435,29 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): raise ValueError("A request_uri outside the registered") # Fetch the request - _resp = endpoint_context.httpc.get(_request_uri, **endpoint_context.httpc_params) + _resp = context.httpc.get(_request_uri, **context.httpc_params) if _resp.status_code == 200: - args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} + args = { + "keyjar": self.upstream_get('attribute', 'keyjar'), + "issuer": client_id + } _ver_request = self.request_cls().from_jwt(_resp.text, **args) self.allowed_request_algorithms( client_id, - endpoint_context, + context, _ver_request.jws_header.get("alg", "RS256"), "sign", ) if _ver_request.jwe_header is not None: self.allowed_request_algorithms( client_id, - endpoint_context, + context, _ver_request.jws_header.get("alg"), "enc_alg", ) self.allowed_request_algorithms( client_id, - endpoint_context, + context, _ver_request.jws_header.get("enc"), "enc_enc", ) @@ -470,11 +471,11 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): return request - def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): + def _post_parse_request(self, request, client_id, context, **kwargs): """ Verify the authorization request. - :param endpoint_context: + :param context: :param request: :param client_id: :param kwargs: @@ -486,9 +487,9 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): request, error="invalid_request", error_description="Can not parse AuthzRequest" ) - request = self.filter_request(endpoint_context, request) + request = self.filter_request(context, request) - _cinfo = endpoint_context.cdb.get(client_id) + _cinfo = context.cdb.get(client_id) if not _cinfo: logger.error("Client ID ({}) not in client database".format(request["client_id"])) return self.authentication_error_response( @@ -505,7 +506,7 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): # Get a verified redirect URI try: - redirect_uri = get_uri(endpoint_context, request, "redirect_uri") + redirect_uri = get_uri(context, request, "redirect_uri") except (RedirectURIError, ParameterError) as err: return self.authentication_error_response( request, @@ -555,7 +556,7 @@ def _enforce_resource_indicators_policy(self, request, config): return self.error_cls(error="server_error", error_description="Internal server error") def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): - _context = self.server_get("context") + _context = self.upstream_get("context") auth_id = kwargs.get("auth_method_id") if auth_id: return _context.authn_broker[auth_id] @@ -579,7 +580,7 @@ def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): } def create_session(self, request, user_id, acr, time_stamp, authn_method): - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager authn_event = create_authn_event( user_id, @@ -657,7 +658,7 @@ def setup_auth( authn_class_ref = res["acr"] client_id = request.get("client_id") - _context = self.server_get("context") + _context = self.upstream_get("context") try: _auth_info = kwargs.get("authn", "") if "upm_answer" in request and request["upm_answer"] == "true": @@ -837,7 +838,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if "response_type" in request and request["response_type"] == ["none"]: fragment_enc = False else: - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager _sinfo = _mngr.get_session_info(sid, grant=True) @@ -944,7 +945,7 @@ def post_authentication(self, request: Union[dict, Message], session_id: str, ** """ response_info = {} - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager # Do the authorization @@ -1013,7 +1014,7 @@ def authz_part2(self, request, session_id, **kwargs): except Exception as err: return self.error_by_response_mode({}, request, "server_error", err) - _context = self.server_get("context") + _context = self.upstream_get("context") logger.debug(f"resp_info: {resp_info}") @@ -1089,7 +1090,7 @@ def process_request( return request _cid = request["client_id"] - _context = self.server_get("context") + _context = self.upstream_get("context") cinfo = _context.cdb[_cid] # logger.debug("client {}: {}".format(_cid, cinfo)) @@ -1142,9 +1143,9 @@ class AllowedAlgorithms: def __init__(self, algorithm_parameters): self.algorithm_parameters = algorithm_parameters - def __call__(self, client_id, endpoint_context, alg, alg_type): - _cinfo = endpoint_context.cdb[client_id] - _pinfo = endpoint_context.provider_info + def __call__(self, client_id, context, alg, alg_type): + _cinfo = context.cdb[client_id] + _pinfo = context.provider_info _reg, _sup = self.algorithm_parameters[alg_type] _allowed = _cinfo.get(_reg) diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index 75b043d9..2ecf3b00 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -29,8 +29,8 @@ class Introspection(Endpoint): ] } - def __init__(self, server_get, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) self.offset = kwargs.get("offset", 0) def _introspect(self, token, client_id, grant): @@ -52,7 +52,7 @@ def _introspect(self, token, client_id, grant): if not aud: aud = grant.resources - _context = self.server_get("context") + _context = self.upstream_get("context") ret = { "active": True, "scope": " ".join(scope), @@ -98,7 +98,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs request_token = _introspect_request["token"] _resp = self.response_cls(active=False) - _context = self.server_get("context") + _context = self.upstream_get("context") try: _session_info = _context.session_manager.get_session_info_by_token( diff --git a/src/idpyoidc/server/oauth2/pushed_authorization.py b/src/idpyoidc/server/oauth2/pushed_authorization.py index 5a6dd7fe..40d319d8 100644 --- a/src/idpyoidc/server/oauth2/pushed_authorization.py +++ b/src/idpyoidc/server/oauth2/pushed_authorization.py @@ -14,8 +14,8 @@ class PushedAuthorization(Authorization): response_format = "json" name = "pushed_authorization" - def __init__(self, server_get, **kwargs): - Authorization.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get, **kwargs): + Authorization.__init__(self, upstream_get, **kwargs) # self.pre_construct.append(self._pre_construct) self.post_parse_request.append(self._post_parse_request) self.ttl = kwargs.get("ttl", 3600) @@ -29,7 +29,7 @@ def process_request(self, request=None, **kwargs): # create URN _urn = "urn:uuid:{}".format(uuid.uuid4()) - self.server_get("context").par_db[_urn] = request + self.upstream_get("context").par_db[_urn] = request return { "http_response": {"request_uri": _urn, "expires_in": self.ttl}, diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 652ec463..2ba4ecf5 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -38,8 +38,8 @@ class Token(Endpoint): "refresh_token": RefreshTokenHelper, } - def __init__(self, server_get, new_refresh_token=False, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get, new_refresh_token=False, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) self.post_parse_request.append(self._post_parse_request) self.allow_refresh = False self.new_refresh_token = new_refresh_token @@ -132,7 +132,7 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar return response_args _access_token = response_args["access_token"] - _context = self.server_get("context") + _context = self.upstream_get("context") if isinstance(_helper, TokenExchangeHelper): _handler_key = _helper.get_handler_key(request, _context) diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index 576a06d7..f55296df 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -62,7 +62,7 @@ def _mint_token( token_args: Optional[dict] = None, token_type: Optional[str] = "", ) -> SessionToken: - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager usage_rules = grant.usage_rules.get(token_class) if usage_rules: @@ -81,7 +81,7 @@ def _mint_token( token = grant.mint_token( session_id, - endpoint_context=_context, + context=_context, token_class=token_class, token_handler=_mngr.token_handler[token_class], based_on=based_on, @@ -159,7 +159,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager logger.debug("Access Token") @@ -310,7 +310,7 @@ def post_parse_request( :returns: """ - _mngr = self.endpoint.server_get("context").session_manager + _mngr = self.endpoint.upstream_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["code"], grant=True, handler_key="authorization_code" @@ -339,7 +339,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager logger.debug("Refresh Token") @@ -433,13 +433,11 @@ def post_parse_request( """ request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.server_get("context") - try: - keyjar = _context.keyjar - except AttributeError: - keyjar = "" + _context = self.endpoint.upstream_get("context") - request.verify(keyjar=keyjar, opponent_id=client_id) + request.verify( + keyjar=self.endpoint.upstream_get('sttribute', 'keyjar'), + opponent_id=client_id) _mngr = _context.session_manager try: @@ -498,19 +496,17 @@ def __init__(self, endpoint, config=None): def post_parse_request(self, request, client_id="", **kwargs): request = TokenExchangeRequest(**request.to_dict()) - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") if "token_exchange" in _context.cdb[request["client_id"]]: config = _context.cdb[request["client_id"]]["token_exchange"] else: config = self.config try: - keyjar = _context.keyjar - except AttributeError: - keyjar = "" - - try: - request.verify(keyjar=keyjar, opponent_id=client_id) + request.verify( + keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), + opponent_id=client_id + ) except ( MissingRequiredAttribute, ValueError, @@ -567,7 +563,7 @@ def post_parse_request(self, request, client_id="", **kwargs): return resp def _enforce_policy(self, request, token, config): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") subject_token_types_supported = config.get( "subject_token_types_supported", self.token_types_mapping.keys() ) @@ -638,7 +634,7 @@ def token_exchange_response(self, token, issued_token_type): return TokenExchangeResponse(**response_args) def process_request(self, request, **kwargs): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager try: _handler_key = self.token_types_mapping[request["subject_token_type"]] diff --git a/src/idpyoidc/server/oidc/add_on/custom_scopes.py b/src/idpyoidc/server/oidc/add_on/custom_scopes.py index 299f619c..8fe6d59a 100644 --- a/src/idpyoidc/server/oidc/add_on/custom_scopes.py +++ b/src/idpyoidc/server/oidc/add_on/custom_scopes.py @@ -18,7 +18,7 @@ def add_custom_scopes(endpoint, **kwargs): _scopes2claims = SCOPE2CLAIMS.copy() _scopes2claims.update(kwargs) - _context = _endpoint.server_get("context") + _context = _endpoint.upstream_get("context") _context.scopes_handler.set_scopes_mapping(_scopes2claims) pi = _context.provider_info diff --git a/src/idpyoidc/server/oidc/add_on/pkce.py b/src/idpyoidc/server/oidc/add_on/pkce.py index 958fd1cd..01952ecb 100644 --- a/src/idpyoidc/server/oidc/add_on/pkce.py +++ b/src/idpyoidc/server/oidc/add_on/pkce.py @@ -30,20 +30,20 @@ def wrapper(code_verifier): } -def post_authn_parse(request, client_id, endpoint_context, **kwargs): +def post_authn_parse(request, client_id, context, **kwargs): """ :param request: :param client_id: - :param endpoint_context: + :param context: :param kwargs: :return: """ - client = endpoint_context.cdb[client_id] + client = context.cdb[client_id] if "pkce_essential" in client: essential = client["pkce_essential"] else: - essential = endpoint_context.args["pkce"].get("essential", False) + essential = context.args["pkce"].get("essential", False) if essential and "code_challenge" not in request: return AuthorizationErrorResponse( error="invalid_request", @@ -55,7 +55,7 @@ def post_authn_parse(request, client_id, endpoint_context, **kwargs): if "code_challenge" in request and ( request["code_challenge_method"] - not in endpoint_context.args["pkce"]["code_challenge_methods"] + not in context.args["pkce"]["code_challenge_methods"] ): return AuthorizationErrorResponse( error="invalid_request", @@ -84,7 +84,7 @@ def verify_code_challenge(code_verifier, code_challenge, code_challenge_method=" return True -def post_token_parse(request, client_id, endpoint_context, **kwargs): +def post_token_parse(request, client_id, context, **kwargs): """ To be used as a post_parse_request function. @@ -98,7 +98,7 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): return request try: - _session_info = endpoint_context.session_manager.get_session_info_by_token( + _session_info = context.session_manager.get_session_info_by_token( request["code"], grant=True, handler_key="authorization_code" ) except KeyError: @@ -147,4 +147,4 @@ def add_pkce_support(endpoint: Dict[str, Endpoint], **kwargs): raise ValueError("Unsupported method: {}".format(method)) kwargs["code_challenge_methods"][method] = CC_METHOD[method] - authn_endpoint.server_get("context").args["pkce"] = kwargs + authn_endpoint.upstream_get("context").args["pkce"] = kwargs diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index bf09e56d..ab6d074b 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -90,8 +90,8 @@ class Authorization(authorization.Authorization): "subject_types_supported": ["public", "pairwise", "ephemeral"], } - def __init__(self, server_get: Callable, **kwargs): - authorization.Authorization.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get: Callable, **kwargs): + authorization.Authorization.__init__(self, upstream_get, **kwargs) # self.pre_construct.append(self._pre_construct) self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) @@ -102,7 +102,7 @@ def do_request_user(self, request_info, **kwargs): else: _login_hint = request_info.get("login_hint") if _login_hint: - _context = self.server_get("context") + _context = self.upstream_get("context") if _context.login_hint_lookup: kwargs["req_user"] = _context.login_hint_lookup(_login_hint) return kwargs diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index aeb069a8..6607cebe 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -43,8 +43,8 @@ class BackChannelAuthentication(Endpoint): "backchannel_user_code_parameter_supported": True, } - def __init__(self, server_get: Callable, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get: Callable, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) # self.pre_construct.append(self._pre_construct) # self.post_parse_request.append(self._do_request_uri) # self.post_parse_request.append(self._post_parse_request) @@ -60,14 +60,14 @@ def do_request_user(self, request): elif request.get("login_hint"): _login_hint = request.get("login_hint") if _login_hint: - _context = self.server_get("context") + _context = self.upstream_get("context") if _context.login_hint_lookup: _request_user = _context.login_hint_lookup(_login_hint) elif request.get("login_hint_token"): - _context = self.server_get("context") + _context = self.upstream_get("context") _request_user = execute( self.parse_login_hint_token, - keyjar=_context.keyjar, + keyjar=self.upstream_get('attribute', 'keyjar'), login_hint_token=request.get("login_hint_token"), context=_context, ) @@ -79,10 +79,10 @@ def allowed_target_uris(self): The OP MUST accept its Issuer Identifier, Token Endpoint URL, or Backchannel Authentication Endpoint URL as values that identify it as an intended audience. """ - _context = self.server_get("context") + _context = self.upstream_get("context") res = [_context.issuer] res.append(self.full_path) - res.append(self.server_get("endpoint", "token").full_path) + res.append(self.upstream_get("endpoint", "token").full_path) return set(res) def process_request( @@ -101,7 +101,7 @@ def process_request( return _error_msg if request_user: # Got a request for a legitimate user, create a session - _context = self.server_get("context") + _context = self.upstream_get("context") _sid = _context.session_manager.create_session( None, request, request_user, client_id=request["client_id"] ) @@ -139,7 +139,7 @@ def _get_session_info(self, request, session_manager): def post_parse_request( self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager _session_id = _mngr.auth_req_id_map[request["auth_req_id"]] _info = _mngr.get_session_info(_session_id) @@ -180,7 +180,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager logger.debug("OIDC Access Token") @@ -299,8 +299,8 @@ class ClientNotification(Endpoint): "backchannel_client_notification_endpoint": None, } - def __init__(self, server_get: Callable, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) + def __init__(self, upstream_get: Callable, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) def process_request( self, @@ -323,7 +323,7 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - endpoint_context: EndpointContext, + context: EndpointContext, request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] diff --git a/src/idpyoidc/server/oidc/discovery.py b/src/idpyoidc/server/oidc/discovery.py index a14643d6..70767c75 100755 --- a/src/idpyoidc/server/oidc/discovery.py +++ b/src/idpyoidc/server/oidc/discovery.py @@ -37,5 +37,5 @@ def do_response(self, response_args=None, request=None, **kwargs): def process_request(self, request=None, **kwargs): return { "subject": request["resource"], - "hrefs": [self.server_get("context").issuer], + "hrefs": [self.upstream_get("context").issuer], } diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index abdfd20b..361d5195 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -14,11 +14,11 @@ class ProviderConfiguration(Endpoint): name = "provider_config" # _supports = {"require_request_uri_registration": None} - def __init__(self, server_get, **kwargs): - Endpoint.__init__(self, server_get=server_get, **kwargs) + def __init__(self, upstream_get, **kwargs): + Endpoint.__init__(self, upstream_get=upstream_get, **kwargs) self.pre_construct.append(self.add_endpoints) - def add_endpoints(self, info, client_id, endpoint_context, **kwargs): + def add_endpoints(self, request, client_id, context, **kwargs): for endpoint in [ "authorization", "provider_config", @@ -26,11 +26,11 @@ def add_endpoints(self, info, client_id, endpoint_context, **kwargs): "userinfo", "session", ]: - endp_instance = self.server_get("endpoint", endpoint) + endp_instance = self.upstream_get("endpoint", endpoint) if endp_instance: info[endp_instance.endpoint_name] = endp_instance.full_path return info def process_request(self, request=None, **kwargs): - return {"response_args": self.server_get("context").provider_info} + return {"response_args": self.upstream_get("context").provider_info} diff --git a/src/idpyoidc/server/oidc/read_registration.py b/src/idpyoidc/server/oidc/read_registration.py index 492532ef..e824415c 100644 --- a/src/idpyoidc/server/oidc/read_registration.py +++ b/src/idpyoidc/server/oidc/read_registration.py @@ -14,17 +14,17 @@ class RegistrationRead(Endpoint): response_format = "json" name = "registration_read" - def get_client_id_from_token(self, endpoint_context, token, request=None): + def get_client_id_from_token(self, context, token, request=None): if "client_id" in request: if ( request["client_id"] - == self.server_get("context").registration_access_token[token] + == self.upstream_get("context").registration_access_token[token] ): return request["client_id"] return "" def process_request(self, request=None, **kwargs): - _cli_info = self.server_get("context").cdb[request["client_id"]] + _cli_info = self.upstream_get("context").cdb[request["client_id"]] args = {k: v for k, v in _cli_info.items() if k in RegistrationResponse.c_param} comb_uri(args) return {"response_args": RegistrationResponse(**args)} diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 395ecbe9..e213f904 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -7,8 +7,10 @@ from urllib.parse import urlencode from urllib.parse import urlparse +from cryptojwt.jws.utils import alg2keytype from cryptojwt.utils import as_bytes +from idpyoidc.client.oidc import PREFERENCE2PROVIDER # from idpyoidc.defaults import PREFERENCE2SUPPORTED from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED @@ -140,7 +142,7 @@ def __init__(self, *args, **kwargs): def match_client_request(self, request: dict) -> list: err = [] - _provider_info = self.server_get("context").provider_info + _provider_info = self.upstream_get("context").provider_info for key, val in request.items(): if key not in REGISTER2PREFERRED: continue @@ -158,7 +160,7 @@ def match_client_request(self, request: dict) -> list: def do_client_registration(self, request, client_id, ignore=None): if ignore is None: ignore = [] - _context = self.server_get("context") + _context = self.upstream_get("context") _cinfo = _context.cdb[client_id].copy() logger.debug("_cinfo: %s" % sanitize(_cinfo)) @@ -221,6 +223,23 @@ def do_client_registration(self, request, client_id, ignore=None): error_description="%s pointed to illegal URL" % item, ) + _keyjar = self.upstream_get('attribute', 'keyjar') + # Do I have the necessary keys + for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: + if item in request: + if request[item] in _context.provider_info[PREFERENCE2PROVIDER[item]]: + ktyp = alg2keytype(request[item]) + # do I have this ktyp and for EC type keys the curve + if ktyp not in ["none", "oct"]: + _k = [] + for iss in ["", _context.issuer]: + _k.extend( + _keyjar.get_signing_key(ktyp, alg=request[item], issuer_id=iss) + ) + if not _k: + logger.warning('Lacking support for "{}"'.format(request[item])) + del _cinfo[item] + t = {"jwks_uri": "", "jwks": None} for item in ["jwks_uri", "jwks"]: @@ -229,10 +248,10 @@ def do_client_registration(self, request, client_id, ignore=None): # if it can't load keys because the URL is false it will # just silently fail. Waiting for better times. - _context.keyjar.load_keys(client_id, jwks_uri=t["jwks_uri"], jwks=t["jwks"]) + _keyjar.load_keys(client_id, jwks_uri=t["jwks_uri"], jwks=t["jwks"]) n_keys = 0 - for kb in _context.keyjar.get(client_id, []): + for kb in _keyjar.get(client_id, []): n_keys += len(kb.keys()) msg = "found {} keys for client_id={}" logger.debug(msg.format(n_keys, client_id)) @@ -298,8 +317,8 @@ def _verify_sector_identifier(self, request): """ si_url = request["sector_identifier_uri"] try: - res = self.server_get("context").httpc.get( - si_url, **self.server_get("context").httpc_params + res = self.upstream_get("context").httpc.get( + si_url, **self.upstream_get("context").httpc_params ) logger.debug("sector_identifier_uri => %s", sanitize(res.text)) except Exception as err: @@ -324,7 +343,7 @@ def add_registration_api(self, cinfo, client_id, context): _rat = rndstr(32) cinfo["registration_access_token"] = _rat - endpoint = self.server_get("endpoints") + endpoint = self.upstream_get("endpoints") cinfo["registration_client_uri"] = "{}?client_id={}".format( endpoint["registration_read"].full_path, client_id ) @@ -370,7 +389,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): error_description=f"Don't support proposed {faulty_claims}" ) - _context = self.server_get("context") + _context = self.upstream_get("context") if new_id: if self.kwargs.get("client_id_generator"): cid_generator = importer(self.kwargs["client_id_generator"]["class"]) @@ -388,7 +407,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): _cinfo = {"client_id": client_id, "client_salt": rndstr(8)} - if self.server_get("endpoint", "registration_read"): + if self.upstream_get("endpoint", "registration_read"): self.add_registration_api(_cinfo, client_id, _context) if new_id: @@ -416,7 +435,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): # Add the client_secret as a symmetric key to the key jar if client_secret: - _context.keyjar.add_symmetric(client_id, str(client_secret)) + self.upstream_get('attribute', 'keyjar').add_symmetric(client_id, str(client_secret)) logger.debug("Stored updated client info in CDB under cid={}".format(client_id)) logger.debug("ClientInfo: {}".format(_cinfo)) @@ -443,7 +462,7 @@ def process_request(self, request=None, new_id=True, set_secret=True, **kwargs): if "error" in reg_resp: return reg_resp else: - _context = self.server_get("context") + _context = self.upstream_get("context") _cookie = _context.new_cookie( name=_context.cookie_handler.name["register"], client_id=reg_resp["client_id"], diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 5743e12c..97182772 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -89,21 +89,21 @@ class Session(Endpoint): "check_session_iframe": None, } - def __init__(self, server_get, **kwargs): + def __init__(self, upstream_get, **kwargs): _csi = kwargs.get("check_session_iframe") if _csi and not _csi.startswith("http"): - kwargs["check_session_iframe"] = add_path(server_get("context").issuer, _csi) - Endpoint.__init__(self, server_get, **kwargs) + kwargs["check_session_iframe"] = add_path(upstream_get("context").issuer, _csi) + Endpoint.__init__(self, upstream_get, **kwargs) self.iv = as_bytes(rndstr(24)) def _encrypt_sid(self, sid): - encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("context").symkey)) + encrypter = AES_GCMEncrypter(key=as_bytes(self.upstream_get("context").symkey)) enc_msg = encrypter.encrypt(as_bytes(sid), iv=self.iv) return as_unicode(b64e(enc_msg)) def _decrypt_sid(self, enc_msg): _msg = b64d(as_bytes(enc_msg)) - encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("context").symkey)) + encrypter = AES_GCMEncrypter(key=as_bytes(self.upstream_get("context").symkey)) ctx, tag = split_ctx_and_tag(_msg) return as_unicode(encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag))) @@ -115,7 +115,7 @@ def do_back_channel_logout(self, cinfo, sid): :return: Tuple with logout URI and signed logout token """ - _context = self.server_get("context") + _context = self.upstream_get("context") try: back_channel_logout_uri = cinfo["backchannel_logout_uri"] @@ -135,7 +135,10 @@ def do_back_channel_logout(self, cinfo, sid): except KeyError: alg = _context.provider_info["id_token_signing_alg_values_supported"][0] - _jws = JWT(_context.keyjar, iss=_context.issuer, lifetime=86400, sign_alg=alg) + _jws = JWT(self.upstream_get('attribute', 'keyjar'), + iss=_context.issuer, + lifetime=86400, + sign_alg=alg) _jws.with_jti = True _logout_token = _jws.pack(payload=payload, recv=cinfo["client_id"]) @@ -143,12 +146,12 @@ def do_back_channel_logout(self, cinfo, sid): def clean_sessions(self, usids): # Revoke all sessions - _context = self.server_get("context") + _context = self.upstream_get("context") for sid in usids: _context.session_manager.revoke_client_session(sid) def logout_all_clients(self, sid): - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager _session_info = _mngr.get_session_info(sid) @@ -217,14 +220,14 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): else: alg = self.kwargs["signing_alg"] - sign_keys = self.server_get("context").keyjar.get_signing_key(alg2keytype(alg)) + sign_keys = self.upstream_get('attribute', 'keyjar').get_signing_key(alg2keytype(alg)) _info = _jwt.verify_compact(keys=sign_keys, sigalg=alg) return _info else: raise ValueError("Not a signed JWT") def logout_from_client(self, sid): - _context = self.server_get("context") + _context = self.upstream_get("context") _cdb = _context.cdb _session_information = _context.session_manager.get_session_info(sid, grant=True) _client_id = _session_information["client_id"] @@ -257,7 +260,7 @@ def process_request( :param kwargs: :return: """ - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager if "post_logout_redirect_uri" in request: @@ -338,7 +341,7 @@ def process_request( logger.debug("JWS payload: {}".format(payload)) # From me to me _jws = JWT( - _context.keyjar, + self.upstream_get('attribute', 'keyjar'), iss=_context.issuer, lifetime=86400, sign_alg=self.kwargs["signing_alg"], @@ -371,9 +374,9 @@ def parse_request(self, request, http_info=None, **kwargs): request["access_token"] = auth_info["token"] if isinstance(request, dict): - _context = self.server_get("context") + _context = self.upstream_get("context") request = self.request_cls(**request) - if not request.verify(keyjar=_context.keyjar, sigalg=""): + if not request.verify(keyjar=self.upstream_get('attribute', 'keyjar'), sigalg=""): raise InvalidRequest("Request didn't verify") # id_token_signing_alg_values_supported try: @@ -398,7 +401,7 @@ def do_verified_logout(self, sid, alla=False, **kwargs): bcl = _res.get("blu") if bcl: - _context = self.server_get("context") + _context = self.upstream_get("context") # take care of Back channel logout first for _cid, spec in bcl.items(): _url, sjwt = spec @@ -421,7 +424,7 @@ def do_verified_logout(self, sid, alla=False, **kwargs): return _res["flu"].values() if _res.get("flu") else [] def kill_cookies(self): - _context = self.server_get("context") + _context = self.upstream_get("context") _handler = _context.cookie_handler session_mngmnt = _handler.make_cookie_content( value="", name=_handler.name["session_management"], max_age=-1 diff --git a/src/idpyoidc/server/oidc/token_helper.py b/src/idpyoidc/server/oidc/token_helper.py index 0a003186..5b8e73f6 100755 --- a/src/idpyoidc/server/oidc/token_helper.py +++ b/src/idpyoidc/server/oidc/token_helper.py @@ -43,7 +43,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): :param kwargs: :return: """ - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager logger.debug("OIDC Access Token") @@ -120,9 +120,9 @@ def process_request(self, req: Union[Message, dict], **kwargs): _response["expires_in"] = token.expires_at - utc_time_sans_frac() if ( - issue_refresh - and "refresh_token" in _supports_minting - and "refresh_token" in grant_types_supported + issue_refresh + and "refresh_token" in _supports_minting + and "refresh_token" in grant_types_supported ): try: refresh_token = self._mint_token( @@ -165,7 +165,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _response def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: """ This is where clients come to get their access tokens @@ -175,7 +175,7 @@ def post_parse_request( :returns: """ - _mngr = self.endpoint.server_get("context").session_manager + _mngr = self.endpoint.upstream_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["code"], grant=True, handler_key="authorization_code" @@ -209,7 +209,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager if req["grant_type"] != "refresh_token": @@ -301,9 +301,9 @@ def process_request(self, req: Union[Message, dict], **kwargs): token.register_usage() if ( - "client_id" in req - and req["client_id"] in _context.cdb - and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] + "client_id" in req + and req["client_id"] in _context.cdb + and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] ): revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") else: @@ -315,7 +315,10 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _resp def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, + request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs ): """ This is where clients come to refresh their access tokens @@ -326,13 +329,10 @@ def post_parse_request( """ request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.server_get("context") - try: - keyjar = _context.keyjar - except AttributeError: - keyjar = "" + _context = self.endpoint.upstream_get("context") - request.verify(keyjar=keyjar, opponent_id=client_id) + request.verify(keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), + opponent_id=client_id) _mngr = _context.session_manager try: diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 9bdb7ce3..78a593bb 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -36,18 +36,18 @@ class UserInfo(Endpoint): "userinfo_encryption_enc_values_supported": work_environment.get_encryption_encs, } - def __init__(self, server_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): + def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): Endpoint.__init__( self, - server_get, + upstream_get, add_claims_by_scope=add_claims_by_scope, **kwargs, ) # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") - def get_client_id_from_token(self, endpoint_context, token, request=None): - _info = endpoint_context.session_manager.get_session_info_by_token( + def get_client_id_from_token(self, context, token, request=None): + _info = context.session_manager.get_session_info_by_token( token, handler_key="access_token" ) return _info["client_id"] @@ -63,7 +63,7 @@ def do_response( if "error" in kwargs and kwargs["error"]: return Endpoint.do_response(self, response_args, request, **kwargs) - _context = self.server_get("context") + _context = self.upstream_get("context") if not client_id: raise MissingValue("client_id") @@ -88,7 +88,7 @@ def do_response( if encrypt or sign: _jwt = JWT( - _context.keyjar, + self.upstream_get('attribute', 'keyjar'), iss=_context.issuer, sign=sign, sign_alg=sign_alg, @@ -112,7 +112,7 @@ def do_response( return {"response": resp, "http_headers": http_headers} def process_request(self, request=None, **kwargs): - _mngr = self.server_get("context").session_manager + _mngr = self.upstream_get("context").session_manager try: _session_info = _mngr.get_session_info_by_token( request["access_token"], grant=True, handler_key="access_token" @@ -147,7 +147,7 @@ def process_request(self, request=None, **kwargs): # pass if allowed: - _cntxt = self.server_get("context") + _cntxt = self.upstream_get("context") _claims_restriction = _cntxt.claims_interface.get_claims( _session_info["branch_id"], scopes=token.scope, claims_release_point="userinfo" ) diff --git a/src/idpyoidc/server/scopes.py b/src/idpyoidc/server/scopes.py index 9aab827a..2e2bed27 100644 --- a/src/idpyoidc/server/scopes.py +++ b/src/idpyoidc/server/scopes.py @@ -49,8 +49,8 @@ def convert_scopes2claims(scopes, allowed_claims=None, scope2claim_map=None): class Scopes: - def __init__(self, server_get, allowed_scopes=None, scopes_to_claims=None): - self.server_get = server_get + def __init__(self, upstream_get, allowed_scopes=None, scopes_to_claims=None): + self.upstream_get = upstream_get if not scopes_to_claims: scopes_to_claims = dict(SCOPE2CLAIMS) self._scopes_to_claims = scopes_to_claims @@ -65,7 +65,7 @@ def get_allowed_scopes(self, client_id=None): """ allowed_scopes = self.allowed_scopes if client_id: - client = self.server_get("context").cdb.get(client_id) + client = self.upstream_get("context").cdb.get(client_id) if client is not None: allowed_scopes = client.get("allowed_scopes", allowed_scopes) return allowed_scopes @@ -79,7 +79,7 @@ def get_scopes_mapping(self, client_id=None): """ scopes_to_claims = self._scopes_to_claims if client_id: - client = self.server_get("context").cdb.get(client_id) + client = self.upstream_get("context").cdb.get(client_id) if client is not None: scopes_to_claims = client.get("scopes_to_claims", scopes_to_claims) return scopes_to_claims diff --git a/src/idpyoidc/server/session/claims.py b/src/idpyoidc/server/session/claims.py index c0a9d263..179ce4ca 100755 --- a/src/idpyoidc/server/session/claims.py +++ b/src/idpyoidc/server/session/claims.py @@ -14,8 +14,8 @@ STANDARD_CLAIMS = [c for c in OpenIDSchema.c_param.keys() if c not in IGNORE] -def available_claims(endpoint_context): - _supported = endpoint_context.provider_info.get("claims_supported") +def available_claims(context): + _supported = context.provider_info.get("claims_supported") if _supported: return _supported else: @@ -26,8 +26,8 @@ class ClaimsInterface: init_args = {"add_claims_by_scope": False, "enable_claims_per_client": False} claims_release_points = ["userinfo", "introspection", "id_token", "access_token"] - def __init__(self, server_get): - self.server_get = server_get + def __init__(self, upstream_get): + self.upstream_get = upstream_get def authorization_request_claims( self, @@ -39,20 +39,20 @@ def authorization_request_claims( return {} - def _get_module(self, usage, endpoint_context): + def _get_module(self, usage, context): module = None if usage == "userinfo": - module = self.server_get("endpoint", "userinfo") + module = self.upstream_get("endpoint", "userinfo") elif usage == "id_token": try: - module = endpoint_context.session_manager.token_handler["id_token"] + module = context.session_manager.token_handler["id_token"] except KeyError: raise ServiceError("No support for ID Tokens") elif usage == "introspection": - module = self.server_get("endpoint", "introspection") + module = self.upstream_get("endpoint", "introspection") elif usage == "access_token": try: - module = endpoint_context.session_manager.token_handler["access_token"] + module = context.session_manager.token_handler["access_token"] except KeyError: raise ServiceError("No support for Access Tokens") @@ -65,7 +65,7 @@ def _client_claims( claims_release_point: str, secondary_identifier: Optional[str] = "", ): - _context = self.server_get("context") + _context = self.upstream_get("context") add_claims_by_scope = _context.cdb[client_id].get("add_claims", {}).get("by_scope", {}) if add_claims_by_scope: _claims_by_scope = add_claims_by_scope.get(claims_release_point) @@ -93,7 +93,7 @@ def get_claims_from_request( client_id: str = None, secondary_identifier: str = "", ) -> dict: - _context = self.server_get("context") + _context = self.upstream_get("context") # which endpoint module configuration to get the base claims from module = self._get_module(claims_release_point, _context) @@ -159,7 +159,7 @@ def get_claims( "userinfo"/"id_token"/"introspection"/"access_token" :return: Claims specification as a dictionary. """ - _context = self.server_get("context") + _context = self.upstream_get("context") session_info = _context.session_manager.get_session_info(session_id, grant=True) client_id = session_info["client_id"] grant = session_info["grant"] @@ -189,7 +189,7 @@ def get_claims_all_usage_from_request( return _claims def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: - grant = self.server_get("context").session_manager.get_grant(session_id) + grant = self.upstream_get("context").session_manager.get_grant(session_id) if grant.authorization_request: auth_req = grant.authorization_request else: @@ -203,7 +203,7 @@ def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: :param claims_restriction: Specifies the upper limit of which claims can be returned :return: """ - meth = self.server_get("context").userinfo + meth = self.upstream_get("context").userinfo if not meth: raise ImproperlyConfigured("userinfo MUST be defined in the configuration") if claims_restriction: @@ -273,13 +273,13 @@ def by_schema(cls, **kwa): class OAuth2ClaimsInterface(ClaimsInterface): claims_release_points = ["introspection", "access_token"] - def _get_module(self, usage, endpoint_context): + def _get_module(self, usage, context): module = None if usage == "introspection": - module = self.server_get("endpoint", "introspection") + module = self.upstream_get("endpoint", "introspection") elif usage == "access_token": try: - module = endpoint_context.session_manager.token_handler["access_token"] + module = context.session_manager.token_handler["access_token"] except KeyError: raise ServiceError("No support for Access Tokens") diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index 761991fe..a191be45 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -181,7 +181,7 @@ def add_acr_value(self, claims_release_point): def payload_arguments( self, session_id: str, - endpoint_context, + context, item: SessionToken, claims_release_point: str, scope: Optional[dict] = None, @@ -191,7 +191,7 @@ def payload_arguments( """ :param session_id: Session ID - :param endpoint_context: EndPoint Context + :param context: EndPoint Context :param item: A SessionToken instance :param claims_release_point: One of "userinfo", "introspection", "id_token", "access_token" :param extra_payload: @@ -231,16 +231,16 @@ def payload_arguments( if item.claims: _claims_restriction = item.claims else: - _claims_restriction = endpoint_context.claims_interface.get_claims( + _claims_restriction = context.claims_interface.get_claims( session_id, scopes=payload["scope"], claims_release_point=claims_release_point, secondary_identifier=secondary_identifier, ) - if endpoint_context.session_manager.node_type[0] == "user": - user_id, _, _ = endpoint_context.session_manager.decrypt_branch_id(session_id) - user_info = endpoint_context.claims_interface.get_user_claims(user_id, + if context.session_manager.node_type[0] == "user": + user_id, _, _ = context.session_manager.decrypt_branch_id(session_id) + user_info = context.claims_interface.get_user_claims(user_id, _claims_restriction) payload.update(user_info) @@ -255,7 +255,7 @@ def payload_arguments( def mint_token( self, session_id: str, - endpoint_context: object, + context: object, token_class: str, token_handler: TokenHandler = None, based_on: Optional[SessionToken] = None, @@ -270,7 +270,7 @@ def mint_token( """ :param session_id: - :param endpoint_context: + :param context: :param token_type: :param token_handler: :param based_on: @@ -343,9 +343,9 @@ def mint_token( **class_args, ) if token_handler is None: - token_handler = endpoint_context.session_manager.token_handler.handler[token_class] + token_handler = context.session_manager.token_handler.handler[token_class] - if token_class in endpoint_context.claims_interface.claims_release_points: + if token_class in context.claims_interface.claims_release_points: claims_release_point = token_class else: claims_release_point = "" @@ -361,7 +361,7 @@ def mint_token( token_payload = self.payload_arguments( session_id, - endpoint_context, + context, item=item, claims_release_point=claims_release_point, scope=scope, @@ -455,19 +455,19 @@ def last_issued_token_of_type(self, token_class): } -def get_usage_rules(token_type, endpoint_context, grant, client_id): +def get_usage_rules(token_type, context, grant, client_id): """ The order of importance: Grant, Client, EndPointContext, Default :param token_type: The type of token - :param endpoint_context: An EndpointContext instance + :param context: An EndpointContext instance :param grant: A Grant instance :param client_id: The client identifier :return: Usage specification """ - _usage = endpoint_context.authz.usage_rules_for(client_id, token_type) + _usage = context.authz.usage_rules_for(client_id, token_type) if not _usage: _usage = DEFAULT_USAGE[token_type] diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 42bf5a5d..2e2338ca 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -544,6 +544,6 @@ def unpack_session_key(self, key): return self.unpack_branch_key(key) -def create_session_manager(server_get, token_handler_args, sub_func=None, conf=None): - _token_handler = handler.factory(server_get, **token_handler_args) +def create_session_manager(upstream_get, token_handler_args, sub_func=None, conf=None): + _token_handler = handler.factory(upstream_get, **token_handler_args) return SessionManager(_token_handler, sub_func=sub_func, conf=conf) diff --git a/src/idpyoidc/server/token/handler.py b/src/idpyoidc/server/token/handler.py index ea52844a..4cebeacc 100755 --- a/src/idpyoidc/server/token/handler.py +++ b/src/idpyoidc/server/token/handler.py @@ -83,7 +83,7 @@ def keys(self): return self.handler.keys() -def init_token_handler(server_get, spec, token_class): +def init_token_handler(upstream_get, spec, token_class): _kwargs = spec.get("kwargs", {}) _lt = spec.get("lifetime") @@ -109,7 +109,7 @@ def init_token_handler(server_get, spec, token_class): ) _kwargs = spec - return cls(token_class=token_class, server_get=server_get, **_kwargs) + return cls(token_class=token_class, upstream_get=upstream_get, **_kwargs) def _add_passwd(keyjar, conf, kid): @@ -142,7 +142,7 @@ def default_token(spec): def factory( - server_get, + upstream_get, code: Optional[dict] = None, token: Optional[dict] = None, refresh: Optional[dict] = None, @@ -169,7 +169,7 @@ def factory( key_defs = [] read_only = False - cwd = server_get("context").cwd + cwd = upstream_get("context").cwd if kwargs.get("jwks_def"): defs = kwargs["jwks_def"] if not jwks_file: @@ -195,9 +195,9 @@ def factory( if default_token(cnf): if kj: _add_passwd(kj, cnf, cls) - args[attr] = init_token_handler(server_get, cnf, token_class_map[cls]) + args[attr] = init_token_handler(upstream_get, cnf, token_class_map[cls]) if id_token is not None: - args["id_token"] = init_token_handler(server_get, id_token, token_class="") + args["id_token"] = init_token_handler(upstream_get, id_token, token_class="") return TokenHandler(**args) diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index bc38850e..f7e8f652 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -27,15 +27,15 @@ } -def include_session_id(endpoint_context, client_id, where): +def include_session_id(context, client_id, where): """ - :param endpoint_context: + :param context: :param client_id: :param where: front or back :return: """ - _pinfo = endpoint_context.provider_info + _pinfo = context.provider_info # Am the OP supposed to support {dir}-channel log out and if so can # it pass sid in logout token and ID Token @@ -50,7 +50,7 @@ def include_session_id(endpoint_context, client_id, where): # Does the client support back-channel logout ? try: - endpoint_context.cdb[client_id]["{}channel_logout_uri".format(where)] + context.cdb[client_id]["{}channel_logout_uri".format(where)] except KeyError: return False @@ -58,7 +58,7 @@ def include_session_id(endpoint_context, client_id, where): def get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type, sign=False, encrypt=False + context, client_info, payload_type, sign=False, encrypt=False ): args = {"sign": sign, "encrypt": encrypt} if sign: @@ -66,10 +66,10 @@ def get_sign_and_encrypt_algorithms( args["sign_alg"] = client_info["{}_signed_response_alg".format(payload_type)] except KeyError: # Fall back to default try: - args["sign_alg"] = endpoint_context.jwx_def["signing_alg"][payload_type] + args["sign_alg"] = context.jwx_def["signing_alg"][payload_type] except KeyError: _def_sign_alg = DEF_SIGN_ALG[payload_type] - _supported = endpoint_context.provider_info.get( + _supported = context.provider_info.get( "{}_signing_alg_values_supported".format(payload_type) ) @@ -86,9 +86,9 @@ def get_sign_and_encrypt_algorithms( args["enc_alg"] = client_info["%s_encrypted_response_alg" % payload_type] except KeyError: try: - args["enc_alg"] = endpoint_context.jwx_def["encryption_alg"][payload_type] + args["enc_alg"] = context.jwx_def["encryption_alg"][payload_type] except KeyError: - _supported = endpoint_context.provider_info.get( + _supported = context.provider_info.get( "{}_encryption_alg_values_supported".format(payload_type) ) if _supported: @@ -98,9 +98,9 @@ def get_sign_and_encrypt_algorithms( args["enc_enc"] = client_info["%s_encrypted_response_enc" % payload_type] except KeyError: try: - args["enc_enc"] = endpoint_context.jwx_def["encryption_enc"][payload_type] + args["enc_enc"] = context.jwx_def["encryption_enc"][payload_type] except KeyError: - _supported = endpoint_context.provider_info.get( + _supported = context.provider_info.get( "{}_encryption_enc_values_supported".format(payload_type) ) if _supported: @@ -120,12 +120,12 @@ def __init__( self, token_class: Optional[str] = "id_token", lifetime: Optional[int] = 300, - server_get: Callable = None, + upstream_get: Callable = None, **kwargs, ): Token.__init__(self, token_class, **kwargs) self.lifetime = lifetime - self.server_get = server_get + self.upstream_get = upstream_get self.kwargs = kwargs self.scope_to_claims = None self.provider_info = construct_provider_info(self.default_capabilities, **kwargs) @@ -150,7 +150,7 @@ def payload( :return: IDToken instance """ - _context = self.server_get("context") + _context = self.upstream_get("context") _mngr = _context.session_manager session_information = _mngr.get_session_info(session_id, grant=True) grant = session_information["grant"] @@ -236,7 +236,7 @@ def sign_encrypt( :return: IDToken as a signed and/or encrypted JWT """ - _context = self.server_get("context") + _context = self.upstream_get("context") client_info = _context.cdb[client_id] alg_dict = get_sign_and_encrypt_algorithms( @@ -255,7 +255,11 @@ def sign_encrypt( if lifetime is None: lifetime = self.lifetime - _jwt = JWT(_context.keyjar, iss=_context.issuer, lifetime=lifetime, **alg_dict) + _jwt = JWT( + self.upstream_get('attribute', 'keyjar'), + iss=_context.issuer, + lifetime=lifetime, + **alg_dict) return _jwt.pack(_payload, recv=client_id) @@ -269,7 +273,7 @@ def __call__( usage_rules: Optional[dict] = None, **kwargs, ) -> str: - _context = self.server_get("context") + _context = self.upstream_get("context") user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id) @@ -307,7 +311,7 @@ def info(self, token): :return: tuple of token type and session id """ - _context = self.server_get("context") + _context = self.upstream_get("context") _jwt = factory(token) if not _jwt: @@ -318,7 +322,9 @@ def info(self, token): client_info = _context.cdb[client_id] alg_dict = get_sign_and_encrypt_algorithms(_context, client_info, "id_token", sign=True) - verifier = JWT(key_jar=_context.keyjar, allowed_sign_algs=alg_dict["sign_alg"]) + verifier = JWT( + key_jar=self.upstream_get('attribute', 'keyjar'), + allowed_sign_algs=alg_dict["sign_alg"]) try: _payload = verifier.unpack(token) except JWSException: diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index 010cc703..5ad7264b 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -19,29 +19,29 @@ class JWTToken(Token): def __init__( - self, - token_class, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = DEFAULT_TOKEN_LIFETIME, - server_get: Callable = None, - token_type: str = "Bearer", - profile: Optional[Union[Message, str]] = JWTAccessToken, - with_jti: Optional[bool] = False, - **kwargs + self, + token_class, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = DEFAULT_TOKEN_LIFETIME, + upstream_get: Callable = None, + token_type: str = "Bearer", + profile: Optional[Union[Message, str]] = JWTAccessToken, + with_jti: Optional[bool] = False, + **kwargs ): Token.__init__(self, token_class, **kwargs) self.token_type = token_type self.lifetime = lifetime self.kwargs = kwargs - _context = server_get("context") - # self.key_jar = keyjar or _context.keyjar + _context = upstream_get("context") + # self.key_jar = keyjar or upstream_get('attribute','keyjar') self.issuer = issuer or _context.issuer self.cdb = _context.cdb - self.server_get = server_get + self.upstream_get = upstream_get self.def_aud = aud or [] self.alg = alg @@ -85,13 +85,13 @@ def __call__( payload = self.load_custom_claims(payload) # payload.update(kwargs) - _context = self.server_get("context") + _context = self.upstream_get("context") if usage_rules and "expires_in" in usage_rules: lifetime = usage_rules.get("expires_in") else: lifetime = self.lifetime signer = JWT( - key_jar=_context.keyjar, + key_jar=self.upstream_get('attribute','keyjar'), iss=self.issuer, lifetime=lifetime, sign_alg=self.alg, @@ -112,8 +112,9 @@ def __call__( return signer.pack(payload) def get_payload(self, token): - _context = self.server_get("context") - verifier = JWT(key_jar=_context.keyjar, allowed_sign_algs=[self.alg]) + _context = self.upstream_get("context") + verifier = JWT(key_jar=self.upstream_get('attribute','keyjar'), + allowed_sign_algs=[self.alg]) try: _payload = verifier.unpack(token) except JWSException: diff --git a/src/idpyoidc/server/user_authn/authn_context.py b/src/idpyoidc/server/user_authn/authn_context.py index b52dcb82..08ab8ffe 100755 --- a/src/idpyoidc/server/user_authn/authn_context.py +++ b/src/idpyoidc/server/user_authn/authn_context.py @@ -120,7 +120,7 @@ def _acr_claim(request): return None -def pick_auth(endpoint_context, areq, pick_all=False): +def pick_auth(context, areq, pick_all=False): """ Pick authentication method @@ -128,8 +128,8 @@ def pick_auth(endpoint_context, areq, pick_all=False): :return: A dictionary with the authentication method and its authn class ref """ acrs = [] - if len(endpoint_context.authn_broker) == 1: - return endpoint_context.authn_broker.default() + if len(context.authn_broker) == 1: + return context.authn_broker.default() if "acr_values" in areq: if not isinstance(areq["acr_values"], list): @@ -144,14 +144,14 @@ def pick_auth(endpoint_context, areq, pick_all=False): if _ith.get("acr"): acrs = [_ith["acr"]] else: - if areq.get("login_hint") and endpoint_context.login_hint2acrs: - acrs = endpoint_context.login_hint2acrs(areq["login_hint"]) + if areq.get("login_hint") and context.login_hint2acrs: + acrs = context.login_hint2acrs(areq["login_hint"]) if not acrs: - return endpoint_context.authn_broker.default() + return context.authn_broker.default() for acr in acrs: - res = endpoint_context.authn_broker.pick(acr) + res = context.authn_broker.pick(acr) logger.debug(f"Picked AuthN broker for ACR {str(acr)}: {str(res)}") if res: return res if pick_all else res[0] @@ -159,7 +159,7 @@ def pick_auth(endpoint_context, areq, pick_all=False): return None -def init_method(authn_spec, server_get, template_handler=None): +def init_method(authn_spec, upstream_get, template_handler=None): try: _args = authn_spec["kwargs"] except KeyError: @@ -168,25 +168,25 @@ def init_method(authn_spec, server_get, template_handler=None): if "template" in _args: _args["template_handler"] = template_handler - _args["server_get"] = server_get + _args["upstream_get"] = upstream_get args = {"method": instantiate(authn_spec["class"], **_args)} args.update({k: v for k, v in authn_spec.items() if k not in ["class", "kwargs"]}) return args -def populate_authn_broker(methods, server_get, template_handler=None): +def populate_authn_broker(methods, upstream_get, template_handler=None): """ :param methods: Authentication method specifications - :param server_get: method that returns things from server + :param upstream_get: method that returns things from server :param template_handler: A class used to render templates :return: """ authn_broker = AuthnBroker() for id, authn_spec in methods.items(): - args = init_method(authn_spec, server_get, template_handler) + args = init_method(authn_spec, upstream_get, template_handler) authn_broker[id] = args return authn_broker diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index 0706ee98..b1dd008c 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -47,9 +47,9 @@ class UserAuthnMethod(object): url_endpoint = "/verify" FAILED_AUTHN = (None, True) - def __init__(self, server_get=None, **kwargs): + def __init__(self, upstream_get=None, **kwargs): self.query_param = "upm_answer" - self.server_get = server_get + self.upstream_get = upstream_get self.kwargs = kwargs def __call__(self, **kwargs): @@ -90,7 +90,7 @@ def verify(self, *args, **kwargs): raise NotImplementedError def unpack_token(self, token): - return verify_signed_jwt(token=token, keyjar=self.server_get("context").keyjar) + return verify_signed_jwt(token=token, keyjar=self.upstream_get("context").keyjar) def done(self, areq): """ @@ -106,7 +106,7 @@ def done(self, areq): return False def cookie_info(self, cookie: List[dict], client_id: str) -> dict: - _context = self.server_get("context") + _context = self.upstream_get("context") logger.debug("Value cookies: {}".format(cookie)) if cookie is None: @@ -157,12 +157,12 @@ def __init__( db, template_handler, template="user_pass.jinja2", - server_get=None, + upstream_get=None, verify_endpoint="", **kwargs, ): - super(UserPassJinja2, self).__init__(server_get=server_get) + super(UserPassJinja2, self).__init__(upstream_get=upstream_get) self.template_handler = template_handler self.template = template @@ -190,12 +190,13 @@ def __call__(self, **kwargs): ), OnlyForTestingWarning, ) - if not self.server_get: - raise Exception(f"{self.__class__.__name__} doesn't have a working server_get") - _context = self.server_get("context") + if not self.upstream_get: + raise Exception(f"{self.__class__.__name__} doesn't have a working upstream_get") + _context = self.upstream_get("context") + _keyjar = self.upstream_get("attribute", 'keyjar') # Stores information need afterwards in a signed JWT that then # appears as a hidden input in the form - jws = create_signed_jwt(_context.issuer, _context.keyjar, **kwargs) + jws = create_signed_jwt(_context.issuer, _keyjar, **kwargs) _kwargs = self.kwargs.copy() for attr in ["policy", "tos", "logo"]: _uri = "{}_uri".format(attr) @@ -218,8 +219,8 @@ def verify(self, *args, **kwargs): class BasicAuthn(UserAuthnMethod): - def __init__(self, pwd, ttl=5, server_get=None): - UserAuthnMethod.__init__(self, server_get=server_get) + def __init__(self, pwd, ttl=5, upstream_get=None): + UserAuthnMethod.__init__(self, upstream_get=upstream_get) self.passwd = pwd self.ttl = ttl @@ -250,8 +251,8 @@ def authenticated_as(self, client_id, cookie=None, authorization="", **kwargs): class SymKeyAuthn(UserAuthnMethod): # user authentication using a token - def __init__(self, ttl, symkey, server_get=None): - UserAuthnMethod.__init__(self, server_get=server_get) + def __init__(self, ttl, symkey, upstream_get=None): + UserAuthnMethod.__init__(self, upstream_get=upstream_get) if symkey is not None and symkey == "": msg = "SymKeyAuthn.symkey cannot be an empty value" @@ -284,8 +285,8 @@ def authenticated_as(self, client_id, cookie=None, authorization="", **kwargs): class NoAuthn(UserAuthnMethod): # Just for testing allows anyone it without authentication - def __init__(self, user, server_get=None): - UserAuthnMethod.__init__(self, server_get=server_get) + def __init__(self, user, upstream_get=None): + UserAuthnMethod.__init__(self, upstream_get=upstream_get) self.user = user self.fail = None diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index 3e00d43f..eea2579d 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -9,7 +9,7 @@ OAUTH2_NOCACHE_HEADERS = [("Pragma", "no-cache"), ("Cache-Control", "no-store")] -def build_endpoints(conf, server_get, issuer): +def build_endpoints(conf, upstream_get, issuer): """ conf typically contains:: @@ -22,7 +22,7 @@ def build_endpoints(conf, server_get, issuer): This function uses class and kwargs to instantiate a class instance with kwargs. :param conf: - :param server_get: Callback function + :param upstream_get: Callback function :param issuer: :return: """ @@ -38,9 +38,9 @@ def build_endpoints(conf, server_get, issuer): # class can be a string (class path) or a class reference if isinstance(spec["class"], str): - _instance = importer(spec["class"])(server_get=server_get, **kwargs) + _instance = importer(spec["class"])(upstream_get=upstream_get, **kwargs) else: - _instance = spec["class"](server_get=server_get, **kwargs) + _instance = spec["class"](upstream_get=upstream_get, **kwargs) try: _path = spec["path"] @@ -121,15 +121,15 @@ def get_http_params(config): return params -def allow_refresh_token(endpoint_context): +def allow_refresh_token(context): # Are there a refresh_token handler - refresh_token_handler = endpoint_context.session_manager.token_handler.handler.get( + refresh_token_handler = context.session_manager.token_handler.handler.get( "refresh_token" ) # Is refresh_token grant type supported _token_supported = False - _supported = endpoint_context.get_preference("grant_types_supported") + _supported = context.get_preference("grant_types_supported") if _supported: if "refresh_token" in _supported: # self.allow_refresh = kwargs.get("allow_refresh", True) @@ -187,7 +187,7 @@ def execute(spec, **kwargs): # return urlunsplit((scheme, hostname, "", "", "")) -# def get_logout_id(endpoint_context, user_id, client_id): +# def get_logout_id(context, user_id, client_id): # _item = NodeInfo() # _item.user_id = user_id # _item.client_id = client_id @@ -196,7 +196,7 @@ def execute(spec, **kwargs): # # It must be possible to map from one to the other. # logout_session_id = uuid.uuid4().hex # # Store the map -# _mngr = endpoint_context.session_manager +# _mngr = context.session_manager # _mngr.set([logout_session_id], _item) # # return logout_session_id diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 5fead1d0..7a9db987 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJvQzlvLUIwZ2dJZzRVeFgxQ0ZEN0hOVFpOTnplZUlSWjh2azZzZTZMR213IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY5NzM0MDAxLCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.adpuPLhsRTs5__3vLjHMPn1nlFYXHq6imhQ6ZAyF5XAwp0TCTNd7ZP6gFtiR-iGOsLFJCbrDyCuGC8opB3c3ETHXVgVnMoE_KzwwwFw20PL4zkq_B0lYgp1PEi9nb7q9a0qVQujb-hkRdq3B8ntaRxdaGnofQ1sP_DtbiqZDyNDbWnT3Wv7H-rLdotStTcZ7KGslQarZJHxN_0m1Mr7Ucon0VL267RciGf0x5pNQ_wQjj9T5uzVOKZMV7gfis_Gr1PlqDPwBDm_tH_10K49mn3exmx1dUlCN-Taw67yTR9Puqo3w6rQxImWusq00LxfHPtl4POZS7RKCQWEU2lQ_wA \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJmMFNXNzRtbzFKSG1NbFAzZUVIWWhKcXZQTm1fZmxwYjBOcTJ0SXYzUXM0IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY1ODIzODQ0LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.iJggLp-EJvEP4ARgGwCFIhLlwikTLV8EEd7D2PX-yW6H9rm261_l-NkKTKmfV_Y2-QLT1X3K0eepI_A1qVAzLzohFSw0OcPPJDRs9IugLxeZ0Ktr9pb29XcCHOU83DD3onIXTfzgihqX_aqUfPt32teD5NTTMmMGuaA700rtJiXzrPXWQmJXDVlStgtjFh4fZI59G3yPUNQqUTm0w_HHsF8IuzIPHFq5FTlixTaX3iu90dm9icXTJtLYxw5uHL7Je2_GxWTmCE9WEOzSI3AaQz-jIsG1RVVBx5WBRngHkcFPITuXCXklOKq_iFbFCcRL-Gt7SsDHqV_zrAm72LaIvg \ No newline at end of file diff --git a/tests/test_12_context.py b/tests/test_12_context.py new file mode 100644 index 00000000..6ee98240 --- /dev/null +++ b/tests/test_12_context.py @@ -0,0 +1,19 @@ +from idpyoidc.context import OidcContext + + +ENTITY_ID = 'https://example.com' + + +class TestDumpLoad(object): + def test_context_with_entity_id(self): + c = OidcContext({}, entity_id=ENTITY_ID) + mem = c.dump() + c2 = OidcContext().load(mem) + assert c2.issuer == ENTITY_ID + + def test_context_with_entity_id_and_keys(self): + c = OidcContext({"entity_id": ENTITY_ID}) + + mem = c.dump() + c2 = OidcContext().load(mem) + assert c2.entity_id == ENTITY_ID diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index b3152e55..7ea1199c 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -31,9 +31,6 @@ def test_get_service(self): assert _srv.service_name == "" assert _srv.request_body_type == "urlencoded" - _srv = self.entity.client_get("service", "") - assert _srv.service_name == "" - def test_get_service_unsupported(self): _srv = self.entity.get_service("foobar") assert _srv is None @@ -47,8 +44,6 @@ def test_get_service_by_endpoint_name(self): _srv.endpoint_name = "flux_endpoint" _fsrv = self.entity.get_service_by_endpoint_name("flux_endpoint") assert _srv == _fsrv - _fsrv2 = self.entity.client_get("service_by_endpoint_name", "flux_endpoint") - assert _fsrv == _fsrv2 def test_get_service_context(self): _context = self.entity.get_service_context() diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 951e94dd..0660c1b1 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -41,8 +41,8 @@ def create_service(self): self.service_context = self.entity.get_service_context() self.service_context.map_supported_to_preferred() - def client_get(self, *args): - if args[0] == "service_context": + def upstream_get(self, *args): + if args[0] == "context": return self.service_context def test_1(self): @@ -112,7 +112,7 @@ def test_parse_response_json(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service_context.keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_json() arg = self.service.parse_response(resp1) assert isinstance(arg, AuthorizationResponse) @@ -124,7 +124,7 @@ def test_parse_response_jwt(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service_context.keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -138,7 +138,7 @@ def test_parse_response_err(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service_context.keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -184,9 +184,9 @@ def test_response(self): _info = self.service.get_request_parameters(request_args=req_args) assert set(_info.keys()) == {"url", "method", "request"} msg = Message().from_urlencoded(self.service.get_urlinfo(_info["url"])) - self.service.client_get("service_context").cstate.set(_state, msg) + self.service.upstream_get("service_context").cstate.set(_state, msg) resp1 = AuthorizationResponse(code="auth_grant", state=_state) response = self.service.parse_response(resp1.to_urlencoded(), "urlencoded", state=_state) self.service.update_service_context(response, key=_state) - assert self.service.client_get("service_context").cstate.get(_state) + assert self.service.upstream_get("service_context").cstate.get(_state) diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 060ef48b..0b40190f 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -92,7 +92,7 @@ def test_quote(): class TestClientSecretBasic(object): def test_construct(self, entity): - _service = entity.client_get("service", "") + _service = entity.get_service("") request = _service.construct( request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) @@ -127,7 +127,7 @@ class TestBearerHeader(object): def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() - http_args = bh.construct(request, service=entity.client_get("service", "")) + http_args = bh.construct(request, service=entity.get_service("")) assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} @@ -136,7 +136,7 @@ def test_construct_with_http_args(self, entity): bh = BearerHeader() # Any HTTP args should just be passed on http_args = bh.construct( - request, service=entity.client_get("service", ""), http_args={"foo": "bar"} + request, service=entity.get_service(""), http_args={"foo": "bar"} ) assert _eq(http_args.keys(), ["foo", "headers"]) @@ -148,7 +148,7 @@ def test_construct_with_headers_in_http_args(self, entity): bh = BearerHeader() http_args = bh.construct( request, - service=entity.client_get("service", ""), + service=entity.get_service(""), http_args={"headers": {"x-foo": "bar"}}, ) @@ -160,14 +160,14 @@ def test_construct_with_resource_request(self, entity): bh = BearerHeader() request = ResourceRequest(access_token="Sesame") - http_args = bh.construct(request, service=entity.client_get("service", "")) + http_args = bh.construct(request, service=entity.get_service("")) assert "access_token" not in request assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} def test_construct_with_token(self, entity): - _service = entity.client_get("service", "") - srv_cntx = _service.client_get("service_context") + _service = entity.get_service("") + srv_cntx = _service.upstream_get("context") _state = srv_cntx.cstate.create_key() srv_cntx.cstate.set(_state, {'iss': "Issuer"}) req = AuthorizationRequest( @@ -186,7 +186,7 @@ def test_construct_with_token(self, entity): ) response = _service.parse_response(resp2.to_urlencoded(), "urlencoded") - _service.client_get("service_context").cstate.update(_state, response) + _service.upstream_get("service_context").cstate.update(_state, response) # and finally use the access token, bound to a state, to # construct the authorization header @@ -197,7 +197,7 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): def test_construct(self, entity): - _token_service = entity.client_get("service", "") + _token_service = entity.get_service("") request = ResourceRequest(access_token="Sesame") http_args = BearerBody().construct(request, service=_token_service) @@ -205,8 +205,8 @@ def test_construct(self, entity): assert http_args is None def test_construct_with_state(self, entity): - _auth_service = entity.client_get("service", "") - _cntx = _auth_service.client_get("service_context") + _auth_service = entity.upstream_get("service", "") + _cntx = _auth_service.upstream_get("service_context") _key = _cntx.cstate.create_key() _cntx.cstate.set(_key, {'iss': "Issuer"}) @@ -228,8 +228,8 @@ def test_construct_with_state(self, entity): assert http_args is None def test_construct_with_request(self, entity): - authz_service = entity.client_get("service", "") - _cntx = authz_service.client_get("service_context") + authz_service = entity.get_service("") + _cntx = authz_service.upstream_get('context') _key = _cntx.cstate.create_key() _cntx.cstate.set(_key, {'iss': "Issuer"}) @@ -240,9 +240,9 @@ def test_construct_with_request(self, entity): resp2 = AccessTokenResponse( access_token="token1", token_type="Bearer", expires_in=0, state=_key ) - _service2 = entity.client_get("service", "") + _service2 = entity.get_service("") response = _service2.parse_response(resp2.to_urlencoded(), "urlencoded") - _service2.client_get("service_context").cstate.update(_key, response) + _service2.upstream_get("service_context").cstate.update(_key, response) request = ResourceRequest() BearerBody().construct(request, service=authz_service, key=_key) @@ -254,7 +254,7 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): def test_construct(self, entity): - _token_service = entity.client_get("service", "") + _token_service = entity.upstream_get("service", "") request = _token_service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() @@ -271,7 +271,7 @@ def test_construct(self, entity): assert http_args is None def test_modify_1(self, entity): - token_service = entity.client_get("service", "") + token_service = entity.upstream_get("service", "") request = token_service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() @@ -279,11 +279,11 @@ def test_modify_1(self, entity): assert "client_secret" in request def test_modify_2(self, entity): - _service = entity.client_get("service", "") + _service = entity.upstream_get("service", "") request = _service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() - _service.client_get("service_context").set_usage('client_secret', "") + _service.upstream_get("context").set_usage('client_secret', "") # this will fail with pytest.raises(AuthnFailure): http_args = csp.construct(request, service=_service) @@ -292,7 +292,7 @@ def test_modify_2(self, entity): class TestPrivateKeyJWT(object): def test_construct(self, entity): - token_service = entity.client_get("service", "") + token_service = entity.get_service("") kb_rsa = KeyBundle( source="file://{}".format(os.path.join(BASE_PATH, "data/keys/rsa.key")), fileformat="der", @@ -301,8 +301,8 @@ def test_construct(self, entity): for key in kb_rsa: key.add_kid() - _context = token_service.client_get("service_context") - _context.keyjar.add_kb("", kb_rsa) + _context = token_service.upstream_get('context') + token_service.upstream_get('attribute', 'keyjar').add_kb("", kb_rsa) _context.provider_info = { "issuer": "https://example.com/", "token_endpoint": "https://example.com/token", @@ -326,7 +326,7 @@ def test_construct(self, entity): assert jso["aud"] == [_context.provider_info["token_endpoint"]] def test_construct_client_assertion(self, entity): - token_service = entity.client_get("service", "") + token_service = entity.get_service("") kb_rsa = KeyBundle( source="file://{}".format(os.path.join(BASE_PATH, "data/keys/rsa.key")), @@ -336,7 +336,7 @@ def test_construct_client_assertion(self, entity): request = AccessTokenRequest() pkj = PrivateKeyJWT() _ca = assertion_jwt( - token_service.client_get("service_context").get_client_id(), + token_service.upstream_get('context').get_client_id(), kb_rsa.get("RSA"), "https://example.com/token", "RS256", @@ -350,7 +350,7 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): def test_client_secret_jwt(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -365,7 +365,7 @@ def test_client_secret_jwt(self, entity): request = AccessTokenRequest() csj.construct( - request, service=entity.client_get("service", ""), authn_endpoint="token_endpoint" + request, service=entity.get_service(""), authn_endpoint="token_endpoint" ) assert request["client_assertion_type"] == JWT_BEARER assert "client_assertion" in request @@ -385,7 +385,7 @@ def test_client_secret_jwt(self, entity): assert info["aud"] == [_service_context.provider_info["token_endpoint"]] def test_get_key_by_kid(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -401,13 +401,14 @@ def test_get_key_by_kid(self, entity): # get a kid _keys = _service_context.keyjar.get_signing_key(key_type="oct") kid = _keys[0].kid - token_service = entity.client_get("service", "accesstoken") + # token_service = entity.get_service("") + token_service = entity.upstream_get("service", "accesstoken") csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) assert "client_assertion" in request def test_get_key_by_kid_fail(self, entity): - token_service = entity.client_get("service", "") - _service_context = token_service.client_get("service_context") + token_service = entity.get_service("") + _service_context = token_service.upstream_get('context') _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -426,7 +427,7 @@ def test_get_key_by_kid_fail(self, entity): csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) def test_get_audience_and_algorithm_default_alg(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -442,7 +443,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): _service_context.registration_response = {} - token_service = entity.client_get("service", "") + token_service = entity.get_service("") # Since I have an RSA key this doesn't fail csj.construct(request, service=token_service, authn_endpoint="token_endpoint") @@ -480,9 +481,9 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): def test_client_secret_jwt(self, entity): - access_token_service = entity.client_get("service", "") + access_token_service = entity.get_service("") - _service_context = access_token_service.client_get("service_context") + _service_context = access_token_service.upstream_get('context') _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { "issuer": "https://example.com/", @@ -516,7 +517,7 @@ def test_client_secret_jwt(self, entity): class TestValidClientInfo(object): def test_valid_service_context(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _now = 123456 # At some time # Expiration time missing or 0, client_secret never expires diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 6d0d1c6f..f1b0c9ec 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -72,7 +72,7 @@ def test_quote(): class TestClientSecretBasic(object): def test_construct(self, entity): - _token_service = entity.client_get("service", "accesstoken") + _token_service = entity.get_service("accesstoken") request = _token_service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) @@ -112,7 +112,7 @@ class TestBearerHeader(object): def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() - http_args = bh.construct(request, service=entity.client_get("service", "accesstoken")) + http_args = bh.construct(request, service=entity.get_service("accesstoken")) assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} @@ -121,7 +121,7 @@ def test_construct_with_http_args(self, entity): bh = BearerHeader() # Any HTTP args should just be passed on http_args = bh.construct( - request, service=entity.client_get("service", "accesstoken"), http_args={"foo": "bar"} + request, service=entity.get_service("accesstoken"), http_args={"foo": "bar"} ) assert _eq(http_args.keys(), ["foo", "headers"]) @@ -133,7 +133,7 @@ def test_construct_with_headers_in_http_args(self, entity): bh = BearerHeader() http_args = bh.construct( request, - service=entity.client_get("service", "accesstoken"), + service=entity.get_service("accesstoken"), http_args={"headers": {"x-foo": "bar"}}, ) @@ -145,14 +145,14 @@ def test_construct_with_resource_request(self, entity): bh = BearerHeader() request = ResourceRequest(access_token="Sesame") - http_args = bh.construct(request, service=entity.client_get("service", "accesstoken")) + http_args = bh.construct(request, service=entity.get_service("accesstoken")) assert "access_token" not in request assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} def test_construct_with_token(self, entity): - authz_service = entity.client_get("service", "authorization") - srv_cntx = authz_service.client_get("service_context") + authz_service = entity.get_service("authorization") + srv_cntx = authz_service.upstream_get("context") _state = srv_cntx.cstate.create_state(iss="Issuer") req = AuthorizationRequest( state=_state, response_type="code", redirect_uri="https://example.com", scope=["openid"] @@ -168,7 +168,7 @@ def test_construct_with_token(self, entity): resp2 = AccessTokenResponse( access_token="token1", token_type="Bearer", expires_in=0, state=_state ) - _token_service = entity.client_get("service", "accesstoken") + _token_service = entity.get_service("accesstoken") response = _token_service.parse_response(resp2.to_urlencoded(), "urlencoded") _token_service.update_service_context(response, key=_state) @@ -182,7 +182,7 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): def test_construct(self, entity): - _token_service = entity.client_get("service", "accesstoken") + _token_service = entity.get_service("accesstoken") request = ResourceRequest(access_token="Sesame") http_args = BearerBody().construct(request, service=_token_service) @@ -190,8 +190,8 @@ def test_construct(self, entity): assert http_args is None def test_construct_with_state(self, entity): - _auth_service = entity.client_get("service", "authorization") - _cntx = _auth_service.client_get("service_context") + _auth_service = entity.get_service("authorization") + _cntx = _auth_service.upstream_get("context") _key = _cntx.cstate.create_state(iss="Issuer") resp = AuthorizationResponse(code="code", state=_key) @@ -212,8 +212,8 @@ def test_construct_with_state(self, entity): assert http_args is None def test_construct_with_request(self, entity): - authz_service = entity.client_get("service", "authorization") - _cntx = authz_service.client_get("service_context") + authz_service = entity.get_service("authorization") + _cntx = authz_service.upstream_get("context") _key = _cntx.cstate.create_state(iss="Issuer") resp1 = AuthorizationResponse(code="auth_grant", state=_key) @@ -223,7 +223,7 @@ def test_construct_with_request(self, entity): resp2 = AccessTokenResponse( access_token="token1", token_type="Bearer", expires_in=0, state=_key ) - _token_service = entity.client_get("service", "accesstoken") + _token_service = entity.get_service("accesstoken") response = _token_service.parse_response(resp2.to_urlencoded(), "urlencoded") _token_service.update_service_context(response, key=_key) @@ -237,7 +237,7 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): def test_construct(self, entity): - _token_service = entity.client_get("service", "accesstoken") + _token_service = entity.get_service("accesstoken") request = _token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() http_args = csp.construct(request, service=_token_service) @@ -253,7 +253,7 @@ def test_construct(self, entity): assert http_args is None def test_modify_1(self, entity): - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") request = token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() # client secret not in request or kwargs @@ -262,12 +262,12 @@ def test_modify_1(self, entity): assert "client_secret" in request def test_modify_2(self, entity): - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") request = token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() # client secret not in request or kwargs del request["client_secret"] - token_service.client_get("service_context").set_usage('client_secret', "") + token_service.upstream_get("context").set_usage('client_secret', "") # this will fail with pytest.raises(AuthnFailure): csp.construct(request, service=token_service) @@ -276,7 +276,7 @@ def test_modify_2(self, entity): class TestPrivateKeyJWT(object): def test_construct(self, entity): - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") kb_rsa = KeyBundle( source="file://{}".format(os.path.join(BASE_PATH, "data/keys/rsa.key")), fileformat="der", @@ -285,12 +285,13 @@ def test_construct(self, entity): for key in kb_rsa: key.add_kid() - _context = token_service.client_get("service_context") - _context.keyjar.add_kb("", kb_rsa) - _context.provider_info = { + _keyjar = token_service.upstream_get("attribute", "keyjar") + _keyjar.add_kb("", kb_rsa) + _keyjar.provider_info = { "issuer": "https://example.com/", "token_endpoint": "https://example.com/token", } + _context = token_service.upstream_get("context") _context.registration_response = {"token_endpoint_auth_signing_alg": "RS256"} token_service.endpoint = "https://example.com/token" @@ -308,7 +309,7 @@ def test_construct(self, entity): assert jso["aud"] == [_context.provider_info["token_endpoint"]] def test_construct_client_assertion(self, entity): - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") kb_rsa = KeyBundle( source="file://{}".format(os.path.join(BASE_PATH, "data/keys/rsa.key")), @@ -318,7 +319,7 @@ def test_construct_client_assertion(self, entity): request = AccessTokenRequest() pkj = PrivateKeyJWT() _ca = assertion_jwt( - token_service.client_get("service_context").get_client_id(), + token_service.upstream_get("context").get_client_id(), kb_rsa.get("RSA"), "https://example.com/token", "RS256", @@ -332,7 +333,7 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): def test_client_secret_jwt(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -347,7 +348,7 @@ def test_client_secret_jwt(self, entity): csj.construct( request, - service=entity.client_get("service", "accesstoken"), + service=entity.get_service("accesstoken"), authn_endpoint="token_endpoint", ) assert request["client_assertion_type"] == JWT_BEARER @@ -368,7 +369,7 @@ def test_client_secret_jwt(self, entity): assert info["aud"] == [_service_context.provider_info["token_endpoint"]] def test_get_key_by_kid(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -384,13 +385,13 @@ def test_get_key_by_kid(self, entity): # get a kid _keys = _service_context.keyjar.get_issuer_keys("") kid = _keys[0].kid - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) assert "client_assertion" in request def test_get_key_by_kid_fail(self, entity): - token_service = entity.client_get("service", "accesstoken") - _service_context = token_service.client_get("service_context") + token_service = entity.get_service("accesstoken") + _service_context = token_service.upstream_get("context") _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -409,7 +410,7 @@ def test_get_key_by_kid_fail(self, entity): csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) def test_get_audience_and_algorithm_default_alg(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -424,7 +425,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): _service_context.registration_response = {} - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") # Add a RSA key to be able to handle default _kb = KeyBundle() @@ -466,9 +467,9 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): def test_client_secret_jwt(self, entity): - access_token_service = entity.client_get("service", "accesstoken") + access_token_service = entity.get_service("accesstoken") - _service_context = access_token_service.client_get("service_context") + _service_context = access_token_service.upstream_get("context") _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { "issuer": "https://example.com/", @@ -502,7 +503,7 @@ def test_client_secret_jwt(self, entity): class TestValidClientInfo(object): def test_valid_service_context(self, entity): - _service_context = entity.client_get("service_context") + _service_context = entity.get_context() _now = 123456 # At some time # Expiration time missing or 0, client_secret never expires diff --git a/tests/test_client_13_service_context.py b/tests/test_client_13_service_context.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index 106d829c..51647cb7 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -122,7 +122,7 @@ def test_registration_userinfo_sign_enc_algs(self): } srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) assert srvcntx.get_sign_alg("userinfo") is None assert srvcntx.get_enc_alg_enc("userinfo") == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} @@ -142,7 +142,7 @@ def test_registration_request_object_sign_enc_algs(self): } srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) res = srvcntx.get_enc_alg_enc("userinfo") # 'sign':'RS256' is an added default @@ -167,7 +167,7 @@ def test_registration_id_token_sign_enc_algs(self): } srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) # 'sign':'RS256' is an added default @@ -235,7 +235,7 @@ def test_verify_alg_support(self): } srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) assert verify_alg_support(srvcntx, "RS256", "id_token", "signing_alg") @@ -256,7 +256,7 @@ def test_import_keys_file(self): self.service_context.import_keys(keyspec) srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) # Now there should be 2, the second a RSA key for signing @@ -272,7 +272,7 @@ def test_import_keys_file_json(self): keyspec = {"file": {"rsa": [file_path]}} self.service_context.import_keys(keyspec) - _sc_state = self.service_context.dump(exclude_attributes=["service_context"]) + _sc_state = self.service_context.dump(exclude_attributes=["context"]) _jsc_state = json.dumps(_sc_state) _o_state = json.loads(_jsc_state) srvcntx = ServiceContext(base_url=BASE_URL).load(_o_state) @@ -302,7 +302,7 @@ def test_import_keys_url(self): self.service_context.keyjar.update() srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["service_context"]) + self.service_context.dump(exclude_attributes=["context"]) ) # Now there should be one belonging to https://example.com diff --git a/tests/test_client_18_service.py b/tests/test_client_18_service.py index d9ee21c0..ea44e815 100644 --- a/tests/test_client_18_service.py +++ b/tests/test_client_18_service.py @@ -37,10 +37,11 @@ def create_service(self): "redirect_uris": ["https://example.com/cli/authz_cb"], "preference": {"response_types": ["code"]}, } + service = {"dummy": {"class": DummyService}} entity = Entity(config=config, services=service) - self.service = DummyService(client_get=entity.client_get, conf={}) + self.service = DummyService(upstream_get=entity.unit_get, conf={}) def test_construct(self): req_args = {"foo": "bar"} diff --git a/tests/test_client_19_webfinger.py b/tests/test_client_19_webfinger.py index 0edc919b..1953d251 100644 --- a/tests/test_client_19_webfinger.py +++ b/tests/test_client_19_webfinger.py @@ -41,7 +41,7 @@ def test_query(): "acct:joe@example.com": ("example.com", rel, "acct%3Ajoe%40example.com"), } - wf = WebFinger(ENTITY.client_get) + wf = WebFinger(ENTITY.upstream_get) for key, args in example_oidc.items(): _q = wf.query(key) p = urlsplit(_q) @@ -99,7 +99,7 @@ def test_query_2(): ), } - wf = WebFinger(ENTITY.client_get) + wf = WebFinger(ENTITY.upstream_get) for key, args in example_oidc.items(): _q = wf.query(key) p = urlsplit(_q) @@ -217,7 +217,7 @@ def test_extra_member_response(): class TestWebFinger(object): def test_query_device(self): - wf = WebFinger(ENTITY.client_get) + wf = WebFinger(ENTITY.upstream_get) request_args = {"resource": "p1.example.com"} _info = wf.get_request_parameters(request_args) p = urlsplit(_info["url"]) @@ -227,7 +227,7 @@ def test_query_device(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_rel(self): - wf = WebFinger(ENTITY.client_get) + wf = WebFinger(ENTITY.upstream_get) request_args = {"resource": "acct:bob@example.com"} _info = wf.get_request_parameters(request_args) p = urlsplit(_info["url"]) @@ -237,7 +237,7 @@ def test_query_rel(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct(self): - wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + wf = WebFinger(ENTITY.upstream_get, rel=OIC_ISSUER) request_args = {"resource": "acct:carol@example.com"} _info = wf.get_request_parameters(request_args=request_args) @@ -248,7 +248,7 @@ def test_query_acct(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct_resource_kwargs(self): - wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + wf = WebFinger(ENTITY.upstream_get, rel=OIC_ISSUER) request_args = {} _info = wf.get_request_parameters( request_args=request_args, resource="acct:carol@example.com" @@ -261,8 +261,8 @@ def test_query_acct_resource_kwargs(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct_resource_config(self): - wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) - wf.client_get("service_context").config["resource"] = "acct:carol@example.com" + wf = WebFinger(ENTITY.entity_get, rel=OIC_ISSUER) + wf.upstream_get("context").config["resource"] = "acct:carol@example.com" request_args = {} _info = wf.get_request_parameters(request_args=request_args) @@ -273,9 +273,9 @@ def test_query_acct_resource_config(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct_no_resource(self): - wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + wf = WebFinger(ENTITY.entity_get, rel=OIC_ISSUER) try: - del wf.client_get("service_context").config["resource"] + del wf.upstream_get("context").config["resource"] except KeyError: pass request_args = {} diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index fca7ba43..2aedc61a 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -65,8 +65,8 @@ def test_construct_authorization_request(self): "response_type": ["code"], } - self.client.client_get("service_context").cstate.set("ABCDE", {"iss": 'issuer'}) - msg = self.client.client_get("service", "authorization").construct(request_args=req_args) + self.client.get_context.cstate.set("ABCDE", {"iss": 'issuer'}) + msg = self.client.get_service("authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg["client_id"] == "client_1" assert msg["redirect_uri"] == "https://example.com/auth_cb" @@ -74,7 +74,7 @@ def test_construct_authorization_request(self): def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} - _context = self.client.client_get("service_context") + _context = self.client.get_context() _context.cstate.set("ABCDE", {"issuer": "issuer"}) auth_request = AuthorizationRequest( @@ -85,9 +85,9 @@ def test_construct_accesstoken_request(self): auth_response = AuthorizationResponse(code="access_code") - self.client.client_get("service_context").cstate.update("ABCDE", auth_response) + self.client.get_context().cstate.update("ABCDE", auth_response) - msg = self.client.client_get("service", "accesstoken").construct( + msg = self.client.get_service("accesstoken").construct( request_args=req_args, state="ABCDE" ) @@ -102,7 +102,7 @@ def test_construct_accesstoken_request(self): } def test_construct_refresh_token_request(self): - _context = self.client.client_get("service_context") + _context = self.client.get_context() _state = "ABCDE" _context.cstate.set(_state, {'iss': "issuer"}) @@ -121,7 +121,7 @@ def test_construct_refresh_token_request(self): _context.cstate.update(_state, token_response) req_args = {} - msg = self.client.client_get("service", "refresh_token").construct( + msg = self.client.get_service("refresh_token").construct( request_args=req_args, state="ABCDE" ) assert isinstance(msg, RefreshAccessTokenRequest) @@ -136,7 +136,7 @@ def test_error_response(self): err = ResponseMessage(error="Illegal") http_resp = MockResponse(400, err.to_urlencoded()) resp = self.client.parse_request_response( - self.client.client_get("service", "authorization"), http_resp + self.client.get_service("authorization"), http_resp ) assert resp["error"] == "Illegal" @@ -147,7 +147,7 @@ def test_error_response_500(self): http_resp = MockResponse(500, err.to_urlencoded()) with pytest.raises(ParseError): self.client.parse_request_response( - self.client.client_get("service", "authorization"), http_resp + self.client.get_service("authorization"), http_resp ) def test_error_response_2(self): @@ -158,7 +158,7 @@ def test_error_response_2(self): with pytest.raises(OidcServiceError): self.client.parse_request_response( - self.client.client_get("service", "authorization"), http_resp + self.client.get_service("authorization"), http_resp ) @@ -201,7 +201,7 @@ def test_keyjar(self): "response_type": ["code"], } - _context = self.client.client_get("service_context") - assert len(_context.keyjar) == 2 # one issuer - assert len(_context.keyjar[""]) == 2 - assert len(_context.keyjar.get("sig")) == 2 + _keyjar = self.client.get_attribute('keyjar') + assert len(_keyjar) == 1 # one issuer + assert len(_keyjar[""]) == 2 + assert len(_keyjar.get("sig")) == 2 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 53a4010c..595e0d45 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -88,12 +88,12 @@ def create_request(self): } entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=make_keyjar(), config=client_config, client_type='oidc') - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() self.context = _context - self.service = entity.client_get("service", "authorization") + self.service = entity.get_service("authorization") def test_construct(self): req_args = {"foo": "bar", "response_type": "code", "state": "state"} @@ -220,7 +220,7 @@ def test_request_param(self): assert os.path.isfile(os.path.join(_dirname, "request123456.jwt")) - _context = self.service.client_get("service_context") + _context = self.service.upstream_get("context") _context.set_usage("redirect_uris", ["https://example.com/cb"]) _context.set_usage("request_uris", ["https://example.com/request123456.jwt"]) _context.base_url = "https://example.com/" @@ -298,7 +298,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.client_get("service_context").work_environment.set_usage("verify_args", { + self.service.upstream_get("context").work_environment.set_usage("verify_args", { "allow_sign_alg_none": allow_sign_alg_none }) resp = AuthorizationResponse(state="state", code="code", id_token=_idt) @@ -326,12 +326,12 @@ def create_request(self): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() - self.service = entity.client_get("service", "authorization") + self.service = entity.get_service("authorization") def test_construct_code(self): req_args = {"foo": "bar", "response_type": "code", "state": "state"} @@ -401,15 +401,15 @@ def create_request(self): "redirect_uris": ["https://example.com/cli/authz_cb"], } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "accesstoken") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("accesstoken") # add some history auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state", response_type="code" ) - _current = entity.client_get("service_context").cstate + _current = entity.get_context().cstate _current.update("state", auth_request) auth_response = AuthorizationResponse(code="access_code") @@ -464,7 +464,7 @@ def test_request_init(self): } def test_id_token_nonce_match(self): - _cstate = self.service.client_get("service_context").cstate + _cstate = self.service.get_context().cstate _cstate.bind_key("nonce", "state") resp = AccessTokenResponse() resp[verified_claim_name("id_token")] = {"nonce": "nonce"} @@ -526,8 +526,8 @@ def create_service(self): } } entity = Entity(keyjar=make_keyjar(), config=client_config, client_type='oidc') - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "provider_info") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("provider_info") def test_construct(self): _req = self.service.construct() @@ -729,7 +729,7 @@ def test_post_parse(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - _context = self.service.client_get("service_context") + _context = self.service.get_context() assert _context.work_environment.use == {} resp = self.service.post_parse_response(provider_info_response) @@ -742,7 +742,7 @@ def test_post_parse(self): # static client registration _context.map_preferred_to_registered() - use_copy = self.service.client_get("service_context").work_environment.use.copy() + use_copy = self.service.upstream_get("context").work_environment.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] @@ -789,7 +789,7 @@ def test_post_parse_2(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - _context = self.service.client_get("service_context") + _context = self.service.upstream_get("context") assert _context.work_environment.use == {} resp = self.service.post_parse_response(provider_info_response) @@ -802,7 +802,7 @@ def test_post_parse_2(self): # static client registration _context.map_preferred_to_registered() - use_copy = self.service.client_get("service_context").work_environment.use.copy() + use_copy = self.service.upstream_get("context").work_environment.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] @@ -863,8 +863,8 @@ def create_request(self): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "registration") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("registration") def test_construct(self): _req = self.service.construct() @@ -882,7 +882,7 @@ def test_construct(self): 'userinfo_signed_response_alg'} def test_config_with_post_logout(self): - self.service.client_get("service_context").work_environment.set_preference( + self.service.upstream_get("context").work_environment.set_preference( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() @@ -915,7 +915,7 @@ def test_config_with_required_request_uri(): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - entity.client_get("service_context").issuer = "https://example.com" + entity.get_context().issuer = "https://example.com" pi_service = entity.client_get("service", "provider_info") pi_service.match_preferences({"require_request_uri_registration": True}) @@ -953,7 +953,7 @@ def test_config_logout_uri(): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" pi_service = entity.client_get("service", "provider_info") @@ -993,16 +993,16 @@ def create_request(self): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "userinfo") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("userinfo") - entity.client_get("service_context").work_environment.use = { + entity.get_context().work_environment.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", } - _cstate = self.service.client_get("service_context").cstate + _cstate = self.service.get_context().cstate # Add history auth_response = AuthorizationResponse(code="access_code") _cstate.update("abcde", auth_response) @@ -1094,7 +1094,7 @@ def test_unpack_aggregated_response_missing_keys(self): def test_unpack_signed_response(self): resp = OpenIDSchema(sub="diana", given_name="Diana", family_name="krall", iss=ISS) sk = ISS_KEY.get_signing_key("rsa", issuer_id=ISS) - alg = self.service.client_get("service_context").get_sign_alg("userinfo") + alg = self.service.upstream_get("context").get_sign_alg("userinfo") _resp = self.service.parse_response( resp.to_jwt(sk, algorithm=alg), state="abcde", sformat="jwt" ) @@ -1104,7 +1104,7 @@ def test_unpack_encrypted_response(self): # Add encryption key _kj = build_keyjar([{"type": "RSA", "use": ["enc"]}], issuer_id="") # Own key jar gets the private key - self.service.client_get("service_context").keyjar.import_jwks( + self.service.upstream_get("service_context").keyjar.import_jwks( _kj.export_jwks(private=True), issuer_id="" ) # opponent gets the public key @@ -1114,7 +1114,7 @@ def test_unpack_encrypted_response(self): sub="diana", given_name="Diana", family_name="krall", iss=ISS, aud="client_id" ) enckey = ISS_KEY.get_encrypt_key("rsa", issuer_id="client_id") - algspec = self.service.client_get("service_context").get_enc_alg_enc( + algspec = self.service.upstream_get("context").get_enc_alg_enc( self.service.service_name ) @@ -1138,11 +1138,11 @@ def create_request(self): } services = {"checksession": {"class": "idpyoidc.client.oidc.check_session.CheckSession"}} entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "check_session") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("check_session") def test_construct(self): - _cstate = self.service.client_get("service_context").cstate + _cstate = self.service.upstream_get("service_context").cstate _cstate.update("abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, CheckSessionRequest) @@ -1166,11 +1166,11 @@ def create_request(self): } services = {"checksession": {"class": "idpyoidc.client.oidc.check_id.CheckID"}} entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "check_id") + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("check_id") def test_construct(self): - _cstate = self.service.client_get("service_context").cstate + _cstate = self.service.upstream_get("service_context").cstate _cstate.set("abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, CheckIDRequest) @@ -1195,14 +1195,14 @@ def create_request(self): } services = {"checksession": {"class": "idpyoidc.client.oidc.end_session.EndSession"}} entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() - self.service = entity.client_get("service", "end_session") + self.service = entity.get_service("end_session") def test_construct(self): - self.service.client_get("service_context").cstate.update( + self.service.upstream_get("service_context").cstate.update( "abcde", {"id_token": "a.signed.jwt"}) _req = self.service.construct(state="abcde") assert isinstance(_req, EndSessionRequest) @@ -1235,11 +1235,11 @@ def test_authz_service_conf(): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=services, client_type='oidc') - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() - service = entity.client_get("service", "authorization") + service = entity.get_service("authorization") req = service.construct() assert "claims" in req @@ -1258,7 +1258,7 @@ def test_jwks_uri_conf(): } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, client_type='oidc') - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() @@ -1284,7 +1284,7 @@ def test_jwks_uri_arg(): services=DEFAULT_OIDC_SERVICES, client_type='oidc' ) - _context = entity.client_get("service_context") + _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() _context.map_preferred_to_registered() diff --git a/tests/test_client_22_oidc.py b/tests/test_client_22_oidc.py index f6ac3f9b..eca2f7e6 100755 --- a/tests/test_client_22_oidc.py +++ b/tests/test_client_22_oidc.py @@ -61,14 +61,14 @@ def test_construct_authorization_request(self): "nonce": "nonce", } - self.client.client_get("service_context").cstate.set("ABCDE", {'iss': "issuer"}) + self.client.get_context().cstate.set("ABCDE", {'iss': "issuer"}) - msg = self.client.client_get("service", "authorization").construct(request_args=req_args) + msg = self.client.get_service("authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg["redirect_uri"] == "https://example.com/auth_cb" def test_construct_accesstoken_request(self): - _context = self.client.client_get("service_context") + _context = self.client.get_context() auth_request = AuthorizationRequest(redirect_uri="https://example.com/cli/authz_cb") _state = _context.cstate.create_key() @@ -83,7 +83,7 @@ def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} - msg = self.client.client_get("service", "accesstoken").construct( + msg = self.client.get_service("accesstoken").construct( request_args=req_args, state=_state ) assert isinstance(msg, AccessTokenRequest) @@ -97,7 +97,7 @@ def test_construct_accesstoken_request(self): } def test_construct_refresh_token_request(self): - _context = self.client.client_get("service_context") + _context = self.client.get_context() _context.cstate.set("ABCDE", {'iss':"issuer"}) auth_request = AuthorizationRequest( @@ -113,7 +113,7 @@ def test_construct_refresh_token_request(self): _context.cstate.update("ABCDE", token_response) req_args = {} - msg = self.client.client_get("service", "refresh_token").construct( + msg = self.client.get_service("refresh_token").construct( request_args=req_args, state="ABCDE" ) assert isinstance(msg, RefreshAccessTokenRequest) @@ -125,7 +125,7 @@ def test_construct_refresh_token_request(self): } def test_do_userinfo_request_init(self): - _context = self.client.client_get("service_context") + _context = self.client.get_context() _state = _context.cstate.create_key() _context.cstate.set(_state, {'iss': "issuer"}) @@ -141,7 +141,7 @@ def test_do_userinfo_request_init(self): token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") _context.cstate.update(_state, token_response) - _srv = self.client.client_get("service", "userinfo") + _srv = self.client.get_service("userinfo") _srv.endpoint = "https://example.com/userinfo" _info = _srv.get_request_parameters(state=_state) assert _info diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index 22ca27c9..55c189d1 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -67,14 +67,14 @@ def create_client(self): client_type='oauth2') if "add_ons" in config: - do_add_ons(config["add_ons"], self.entity.client_get("services")) - _context = self.entity.get_service_context() + do_add_ons(config["add_ons"], self.entity.get_services()) + _context = self.entity.get_context() _context.map_supported_to_preferred() _context.map_preferred_to_registered() def test_add_code_challenge_default_values(self): - auth_serv = self.entity.client_get("service", "authorization") - _state_key = self.entity.client_get("service_context").cstate.create_state(iss="Issuer") + auth_serv = self.entity.get_service("authorization") + _state_key = self.entity.get_context().cstate.create_state(iss="Issuer") request_args, _ = add_code_challenge({"state": _state_key}, auth_serv) # default values are length:64 method:S256 @@ -86,7 +86,7 @@ def test_add_code_challenge_default_values(self): def test_authorization_and_pkce(self): auth_serv = self.entity.client_get("service", "authorization") - _state = self.entity.client_get("service_context").cstate.create_state(iss="Issuer") + _state = self.entity.get_context().state.create_state(iss="Issuer") request = auth_serv.construct_request({"state": _state, "response_type": "code"}) assert set(request.keys()) == { @@ -99,13 +99,16 @@ def test_authorization_and_pkce(self): } def test_access_token_and_pkce(self): - authz_service = self.entity.client_get("service", "authorization") + authz_service = self.entity.get_service("authorization") request = authz_service.construct_request({"state": "state", "response_type": "code"}) _state = request["state"] auth_response = AuthorizationResponse(code="access code") - self.entity.client_get("service_context").cstate.update(_state, auth_response) + _context = self.entity.get_context() + _context.cstate.update(_state, auth_response) + auth_serv = self.entity.get_service("authorization") + _state = _context.cstate.create_state(iss="Issuer") - token_service = self.entity.client_get("service", "accesstoken") + token_service = self.entity.get_service("accesstoken") request = token_service.construct_request(state=_state) assert set(request.keys()) == { "client_id", @@ -134,10 +137,10 @@ def create_client(self): } self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) if "add_ons" in config: - do_add_ons(config["add_ons"], self.entity.client_get("services")) + do_add_ons(config["add_ons"], self.entity.get_services()) def test_add_code_challenge_spec_values(self): - auth_serv = self.entity.client_get("service", "authorization") + auth_serv = self.entity.get_service("authorization") request_args, _ = add_code_challenge({"state": "state"}, auth_serv) assert set(request_args.keys()) == {"code_challenge", "code_challenge_method", "state"} assert request_args["code_challenge_method"] == "S384" diff --git a/tests/test_client_25_cc_oauth2_service.py b/tests/test_client_25_cc_oauth2_service.py index 282b066e..c7c6ae62 100644 --- a/tests/test_client_25_cc_oauth2_service.py +++ b/tests/test_client_25_cc_oauth2_service.py @@ -29,12 +29,12 @@ def create_service(self): self.entity = Entity(config=client_config, services=services) - self.entity.client_get("service", "accesstoken").endpoint = "https://example.com/token" - self.entity.client_get("service", "refresh_token").endpoint = "https://example.com/token" + self.entity.get_service("accesstoken").endpoint = "https://example.com/token" + self.entity.get_service("refresh_token").endpoint = "https://example.com/token" def test_token_get_request(self): request_args = {"grant_type": "client_credentials"} - _srv = self.entity.client_get("service", "accesstoken") + _srv = self.entity.get_service("accesstoken") _info = _srv.get_request_parameters(request_args=request_args) assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" @@ -46,7 +46,7 @@ def test_token_get_request(self): def test_token_parse_response(self): request_args = {"grant_type": "client_credentials"} - _srv = self.entity.client_get("service", "accesstoken") + _srv = self.entity.get_service("accesstoken") _request_info = _srv.get_request_parameters(request_args=request_args) response = AccessTokenResponse( @@ -63,11 +63,11 @@ def test_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").cstate.get(_key) + info = _srv.upstream_get("context").cstate.get(_key) assert "__expires_at" in info def test_refresh_token_get_request(self): - _srv = self.entity.client_get("service", "accesstoken") + _srv = self.entity.get_service("accesstoken") _srv.update_service_context( { "access_token": "2YotnFZFEjr1zCsicMWpAA", @@ -77,7 +77,7 @@ def test_refresh_token_get_request(self): "example_parameter": "example_value", } ) - _srv = self.entity.client_get("service", "refresh_token") + _srv = self.entity.get_service("refresh_token") _info = _srv.get_request_parameters(state='') assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" @@ -89,7 +89,7 @@ def test_refresh_token_get_request(self): def test_refresh_token_parse_response(self): request_args = {"grant_type": "client_credentials"} - _srv = self.entity.client_get("service", "accesstoken") + _srv = self.entity.get_service("accesstoken") _request_info = _srv.get_request_parameters(request_args=request_args) response = AccessTokenResponse( @@ -106,12 +106,12 @@ def test_refresh_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").cstate.get(_key) + info = _srv.upstream_get("context").cstate.get(_key) assert "__expires_at" in info # Move from token to refresh token service - _srv = self.entity.client_get("service", "refresh_token") + _srv = self.entity.get_service("refresh_token") _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) refresh_response = AccessTokenResponse( @@ -125,12 +125,12 @@ def test_refresh_token_parse_response(self): _response = _srv.parse_response(refresh_response.to_json(), sformat="json") _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").cstate.get(_key) + info = _srv.upstream_get("context").cstate.get(_key) assert "__expires_at" in info def test_2nd_refresh_token_parse_response(self): request_args = {"grant_type": "client_credentials"} - _srv = self.entity.client_get("service", "accesstoken") + _srv = self.entity.get_service("accesstoken") _request_info = _srv.get_request_parameters(request_args=request_args) response = AccessTokenResponse( @@ -147,12 +147,12 @@ def test_2nd_refresh_token_parse_response(self): # since no state attribute is involved, a key is minted _key = rndstr(16) _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").cstate.get(_key) + info = _srv.upstream_get("context").cstate.get(_key) assert "__expires_at" in info # Move from token to refresh token service - _srv = self.entity.client_get("service", "refresh_token") + _srv = self.entity.get_service("refresh_token") _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) refresh_response = AccessTokenResponse( @@ -166,7 +166,7 @@ def test_2nd_refresh_token_parse_response(self): _response = _srv.parse_response(refresh_response.to_json(), sformat="json") _srv.update_service_context(_response, key=_key) - info = _srv.client_get("service_context").cstate.get(_key) + info = _srv.upstream_get("context").cstate.get(_key) assert "__expires_at" in info _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) diff --git a/tests/test_client_26_read_registration.py b/tests/test_client_26_read_registration.py index b295c500..dc53189e 100644 --- a/tests/test_client_26_read_registration.py +++ b/tests/test_client_26_read_registration.py @@ -41,8 +41,8 @@ def create_request(self): _context.map_supported_to_preferred() _context.map_preferred_to_registered() - self.reg_service = self.entity.client_get("service", "registration") - self.read_service = self.entity.client_get("service", "registration_read") + self.reg_service = self.entity.get_service("registration") + self.read_service = self.entity.get_service("registration_read") def test_construct(self): self.reg_service.endpoint = "{}/registration".format(ISS) diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 7c99920b..4e960ba5 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -155,7 +155,7 @@ def test_conversation(): entity = Entity(config=config, keyjar=RP_KEYJAR, client_type='oidc') - assert set(entity.client_get("services").keys()) == { + assert set(entity.get_services().keys()) == { "accesstoken", "authorization", "webfinger", @@ -165,11 +165,11 @@ def test_conversation(): "provider_info", 'end_session', } - service_context = entity.client_get("service_context") + service_context = entity.get_context() # ======================== WebFinger ======================== - webfinger_service = entity.client_get("service", "webfinger") + webfinger_service = entity.get_service("webfinger") info = webfinger_service.get_request_parameters(request_args={"resource": "foobar@example.org"}) assert ( @@ -202,10 +202,10 @@ def test_conversation(): ] webfinger_service.update_service_context(resp=response) - entity.client_get("service_context").issuer = OP_BASEURL + entity.get_context().issuer = OP_BASEURL # =================== Provider info discovery ==================== - provider_info_service = entity.client_get("service", "provider_info") + provider_info_service = entity.get_service("provider_info") info = provider_info_service.get_request_parameters() assert info["url"] == "https://example.org/op/.well-known/openid" "-configuration" @@ -406,13 +406,13 @@ def test_conversation(): assert isinstance(resp, ProviderConfigurationResponse) provider_info_service.update_service_context(resp, '') - _pi = entity.client_get("service_context").provider_info + _pi = entity.get_context().provider_info assert _pi["issuer"] == OP_BASEURL assert _pi["authorization_endpoint"] == "https://example.org/op/authorization" assert _pi["registration_endpoint"] == "https://example.org/op/registration" # =================== Client registration ==================== - registration_service = entity.client_get("service", "registration") + registration_service = entity.get_service("registration") info = registration_service.get_request_parameters() assert info["url"] == "https://example.org/op/registration" @@ -480,7 +480,7 @@ def test_conversation(): STATE = "Oh3w3gKlvoM2ehFqlxI3HIK5" NONCE = "UvudLKz287YByZdsY3AJoPAlEXQkJ0dK" - auth_service = entity.client_get("service", "authorization") + auth_service = entity.get_service("authorization") _cstate = service_context.cstate info = auth_service.get_request_parameters(request_args={"state": STATE, "nonce": NONCE}) @@ -516,7 +516,7 @@ def test_conversation(): # =================== Access token ==================== - token_service = entity.client_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") request_args = {"state": STATE, "redirect_uri": service_context.get_usage("redirect_uris")[0]} info = token_service.get_request_parameters(request_args=request_args) @@ -584,7 +584,7 @@ def test_conversation(): # =================== User info ==================== - userinfo_service = entity.client_get("service", "userinfo") + userinfo_service = entity.get_service("userinfo") info = userinfo_service.get_request_parameters(state=STATE) assert info["url"] == "https://example.org/op/userinfo" diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 566430a8..d59ae361 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -228,14 +228,14 @@ def test_pick_config(self): def test_init_client(self): client = self.rph.init_client("github") - assert set(client.client_get("services").keys()) == { + assert set(client.get_services().keys()) == { "authorization", "accesstoken", "userinfo", "refresh_token", } - _context = client.client_get("service_context") + _context = client.get_context() # Neither provider info discovery not client registration has been done # So only preferences so far. @@ -275,8 +275,8 @@ def test_do_provider_info(self): # Make sure the service endpoints are set for service_type in ["authorization", "accesstoken", "userinfo"]: - _srv = client.client_get("service", service_type) - _endp = client.client_get("service_context").get("provider_info")[_srv.endpoint_name] + _srv = client.get_service(service_type) + _endp = client.get_context().get("provider_info")[_srv.endpoint_name] assert _srv.endpoint == _endp def test_do_client_registration(self): @@ -288,14 +288,14 @@ def test_do_client_registration(self): assert self.rph.hash2issuer["github"] == issuer assert ( - client.client_get("service_context").get_preference('callback_uris').get( + client.get_context().get_preference('callback_uris').get( "post_logout_redirect_uris") is None ) def test_do_client_setup(self): client = self.rph.client_setup("github") _github_id = iss_id("github") - _context = client.client_get("service_context") + _context = client.get_context() # Neither provider info discovery not client registration has been done # So only preferences so far. @@ -310,18 +310,18 @@ def test_do_client_setup(self): assert len(keys) == 2 for service_type in ["authorization", "accesstoken", "userinfo"]: - _srv = client.client_get("service", service_type) - _endp = _srv.client_get("service_context").get("provider_info")[_srv.endpoint_name] + _srv = client.get_service(service_type) + _endp = _srv.upstream_get("context").get("provider_info")[_srv.endpoint_name] assert _srv.endpoint == _endp assert self.rph.hash2issuer["github"] == _context.get("issuer") def test_create_callbacks(self): client = self.rph.init_client("https://op.example.com/") - _srv = client.client_get("service", "registration") - _context = _srv.client_get("service_context") + _srv = client.get_service("registration") + _context = _srv.upstream_get("context") - cb = _srv.client_get("service_context").get_preference('callback_uris') + cb = _context.get_preference('callback_uris') assert set(cb.keys()) == {"request_uris", "redirect_uris"} assert set(cb['redirect_uris'].keys()) == {'code'} @@ -340,7 +340,7 @@ def test_begin(self): client = self.rph.issuer2rp[_github_id] - assert client.client_get("service_context").issuer == _github_id + assert client.get_context().issuer == _github_id part = urlsplit(res["url"]) assert part.scheme == "https" @@ -378,7 +378,7 @@ def test_get_client_from_session_key(self): # redo self.rph.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").set_preference("redirect_uris", []) + cli2.get_context().set_preference("redirect_uris", []) self.rph.do_client_registration(state=res["state"]) def test_finalize_auth(self): @@ -389,7 +389,7 @@ def test_finalize_auth(self): auth_response = AuthorizationResponse(code="access_code", state=res["state"]) resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} - _state = client.client_get("service_context").cstate.get(res["state"]) + _state = client.get_context().cstate.get(res["state"]) assert set(_state.keys()) == {'client_id', 'code', 'iss', @@ -418,7 +418,7 @@ def test_get_tokens(self): client = self.rph.issuer2rp[_session['iss']] _github_id = iss_id("github") - _context = client.client_get("service_context") + _context = client.get_context() _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] @@ -448,7 +448,7 @@ def test_get_tokens(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url auth_response = AuthorizationResponse(code="access_code", state=res["state"]) resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) @@ -463,7 +463,7 @@ def test_get_tokens(self): "__expires_at", } - _curr = client.client_get("service_context").cstate.get(res["state"]) + _curr = client.get_context().cstate.get(res["state"]) assert set(_curr.keys()) == {'__expires_at', '__verified_id_token', 'access_token', @@ -483,7 +483,7 @@ def test_access_and_id_token(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session['iss']] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session['iss'] _aud = _context.get_client_id() @@ -516,7 +516,7 @@ def test_access_and_id_token(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) @@ -528,7 +528,7 @@ def test_access_and_id_token_by_reference(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session['iss']] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session['iss'] _aud = _context.get_client_id() @@ -561,7 +561,7 @@ def test_access_and_id_token_by_reference(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) _ = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) @@ -573,7 +573,7 @@ def test_get_user_info(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session['iss']] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session['iss'] _aud = _context.get_client_id() @@ -606,7 +606,7 @@ def test_get_user_info(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) @@ -622,7 +622,7 @@ def test_get_user_info(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "userinfo").endpoint = _url + client.get_service("userinfo").endpoint = _url userinfo_resp = self.rph.get_user_info(res["state"], client, token_resp["access_token"]) assert userinfo_resp @@ -631,7 +631,7 @@ def test_userinfo_in_id_token(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session['iss']] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session['iss'] _aud = _context.get_client_id() @@ -654,7 +654,7 @@ def test_userinfo_in_id_token(self): def test_get_provider_specific_service(): srv_desc = {"access_token": {"class": "idpyoidc.client.provider.github.AccessToken"}} entity = Entity(services=srv_desc, config={}) - assert entity.client_get("service", "accesstoken").response_body_type == "urlencoded" + assert entity.get_service("accesstoken").response_body_type == "urlencoded" class TestRPHandlerTier2(object): @@ -664,7 +664,7 @@ def rphandler_setup(self): res = self.rph.begin(issuer_id="github") _session = self.rph.get_session_information(res["state"]) client = self.rph.issuer2rp[_session['iss']] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session['iss'] _aud = _context.get_client_id() @@ -699,7 +699,7 @@ def rphandler_setup(self): status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) @@ -716,7 +716,7 @@ def rphandler_setup(self): status=200, ) - client.client_get("service", "userinfo").endpoint = _url + client.get_service("userinfo").endpoint = _url self.rph.get_user_info(res["state"], client, token_resp["access_token"]) self.state = res["state"] @@ -744,7 +744,7 @@ def test_refresh_access_token(self): status=200, ) - client.client_get("service", "refresh_token").endpoint = _url + client.get_service("refresh_token").endpoint = _url res = self.rph.refresh_access_token(self.state, client, "openid email") assert res["access_token"] == "2nd_accessTok" @@ -761,7 +761,7 @@ def test_get_user_info(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "userinfo").endpoint = _url + client.get_service("userinfo").endpoint = _url resp = self.rph.get_user_info(self.state, client) assert set(resp.keys()) == {"sub", "mail"} @@ -881,7 +881,7 @@ def rphandler_setup(self): self.issuer = "https://github.com/login/oauth/authorize" self.mock_op = MockOP(issuer=self.issuer) self.rph = RPHandler( - BASE_URL, client_configs=CLIENT_CONFIG, http_lib=self.mock_op, keyjar=CLI_KEY + BASE_URL, client_configs=CLIENT_CONFIG, httpc=self.mock_op, keyjar=CLI_KEY ) def test_finalize(self): @@ -892,7 +892,7 @@ def test_finalize(self): p = urlparse(CLIENT_CONFIG["github"]["provider_info"]["authorization_endpoint"]) self.mock_op.register_get_response(p.path, "Redirect", 302) - _ = client.http(auth_query["url"]) + _ = client.httpc(auth_query["url"]) # the user is redirected back to the RP with a positive response auth_response = AuthorizationResponse(code="access_code", state=auth_query["state"]) @@ -923,7 +923,7 @@ def test_finalize(self): ) _github_id = iss_id("github") - client.client_get("service_context").keyjar.import_jwks( + client.get_context().keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id ) diff --git a/tests/test_client_29_pushed_auth.py b/tests/test_client_29_pushed_auth.py index 47a0de64..1f8901a3 100644 --- a/tests/test_client_29_pushed_auth.py +++ b/tests/test_client_29_pushed_auth.py @@ -47,18 +47,18 @@ def create_client(self): } self.entity = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) - self.entity.client_get("service_context").provider_info = { + self.entity.get_context().provider_info = { "pushed_authorization_request_endpoint": "https://as.example.com/push" } def test_authorization(self): - auth_service = self.entity.client_get("service", "authorization") + auth_service = self.entity.get_service("authorization") req_args = {"foo": "bar", "response_type": "code"} with responses.RequestsMock() as rsps: _resp = {"request_uri": "urn:example:bwc4JK-ESC0w8acc191e-Y1LTC2", "expires_in": 3600} rsps.add( "GET", - auth_service.client_get("service_context").provider_info[ + auth_service.upstream_get("context").provider_info[ "pushed_authorization_request_endpoint" ], body=json.dumps(_resp), diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index bff069ee..a7dbfbf7 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -24,7 +24,7 @@ def test_pick_config(self): def test_init_client(self): client = self.rph.init_client("") - assert set(client.client_get("services").keys()) == { + assert set(client.get_services().keys()) == { "registration", "provider_info", "authorization", @@ -33,7 +33,7 @@ def test_init_client(self): "refresh_token", } - _context = client.client_get("service_context") + _context = client.get_context() assert set(_context.work_environment.prefer.keys()) == { 'application_type', @@ -49,8 +49,9 @@ def test_init_client(self): 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported'} - assert list(_context.keyjar.owners()) == ["", BASE_URL] - keys = _context.keyjar.get_issuer_keys("") + _keyjar = client.get_attribute('keyjar') + assert list(_keyjar.owners()) == ["", BASE_URL] + keys = _keyjar.get_issuer_keys("") assert len(keys) == 2 assert _context.base_url == BASE_URL @@ -80,10 +81,10 @@ def test_begin(self): issuer = self.rph.do_provider_info(client) - _context = client.client_get("service_context") + _context = client.get_context() # Calculating request so I can build a reasonable response - _req = client.client_get("service", "registration").construct_request() + _req = client.get_service("registration").construct_request() with responses.RequestsMock() as rsps: request_uri = _context.get("provider_info")["registration_endpoint"] @@ -152,13 +153,13 @@ def test_begin_2(self): issuer = self.rph.do_provider_info(client) - _context = client.client_get("service_context") + _context = client.get_context() # Calculating request so I can build a reasonable response # Publishing a JWKS instead of a JWKS_URI _context.jwks_uri = "" _context.jwks = _context.keyjar.export_jwks() - _req = client.client_get("service", "registration").construct_request() + _req = client.get_service("registration").construct_request() with responses.RequestsMock() as rsps: request_uri = _context.get("provider_info")["registration_endpoint"] diff --git a/tests/test_client_31_oauth2_persistent.py b/tests/test_client_31_oauth2_persistent.py index 6c6e3dd1..16b275bf 100644 --- a/tests/test_client_31_oauth2_persistent.py +++ b/tests/test_client_31_oauth2_persistent.py @@ -52,7 +52,7 @@ class TestClient(object): def test_construct_accesstoken_request(self): # Client 1 starts the chain of event client_1 = Client(config=CONF) - _context_1 = client_1.client_get("service_context") + _context_1 = client_1.get_context() _state = _context_1.cstate.create_state(iss="issuer") auth_request = AuthorizationRequest( @@ -65,13 +65,13 @@ def test_construct_accesstoken_request(self): client_2 = Client(config=CONF) _state_dump = _context_1.dump() - _context2 = client_2.client_get("service_context") + _context2 = client_2.get_context() _context2.load(_state_dump) auth_response = AuthorizationResponse(code="access_code") _context2.cstate.update(_state, auth_response) - msg = client_2.client_get("service", "accesstoken").construct(request_args={}, state=_state) + msg = client_2.get_service("accesstoken").construct(request_args={}, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { @@ -86,32 +86,32 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): # Client 1 starts the chain event client_1 = Client(config=CONF) - _state = client_1.client_get("service_context").cstate.create_state(iss="issuer") + _state = client_1.get_context().cstate.create_state(iss="issuer") auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").cstate.update(_state, auth_request) + client_1.get_context().cstate.update(_state, auth_request) # Client 2 carries on client_2 = Client(config=CONF) - _state_dump = client_1.client_get("service_context").dump() - client_2.client_get("service_context").load(_state_dump) + _state_dump = client_1.get_context().dump() + client_2.get_context().load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").cstate.update(_state, auth_response) + client_2.get_context().cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").cstate.update(_state, token_response) + client_2.get_context().cstate.update(_state, token_response) # Next up is Client 1 - _state_dump = client_2.client_get("service_context").dump() - client_1.client_get("service_context").load(_state_dump) + _state_dump = client_2.get_context().dump() + client_1.get_context().load(_state_dump) req_args = {} - msg = client_1.client_get("service", "refresh_token").construct( + msg = client_1.get_service("refresh_token").construct( request_args=req_args, state=_state ) assert isinstance(msg, RefreshAccessTokenRequest) diff --git a/tests/test_client_32_oidc_persistent.py b/tests/test_client_32_oidc_persistent.py index cd8d75fc..0f5c34ae 100755 --- a/tests/test_client_32_oidc_persistent.py +++ b/tests/test_client_32_oidc_persistent.py @@ -51,23 +51,23 @@ class TestClient(object): def test_construct_accesstoken_request(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) + _state = client_1.get_context().cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").cstate.update(_state, auth_request) + client_1.get_context().cstate.update(_state, auth_request) # Client 2 carries on client_2 = RP(config=CONF) - _state_dump = client_1.client_get("service_context").dump() - client_2.client_get("service_context").load(_state_dump) + _state_dump = client_1.get_context().dump() + client_2.get_context().load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").cstate.update(_state, auth_response) + client_2.get_context().cstate.update(_state, auth_response) # Bind access code to state req_args = {} - msg = client_2.client_get("service", "accesstoken").construct( + msg = client_2.get_service("accesstoken").construct( request_args=req_args, state=_state ) assert isinstance(msg, AccessTokenRequest) @@ -83,31 +83,31 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) + _state = client_1.get_context().cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.client_get("service_context").cstate.update(_state,auth_request) + client_1.get_context().cstate.update(_state,auth_request) # Client 2 carries on client_2 = RP(config=CONF) - _state_dump = client_1.client_get("service_context").dump() - client_2.client_get("service_context").load(_state_dump) + _state_dump = client_1.get_context().dump() + client_2.get_context().load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").cstate.update(_state, auth_response) + client_2.get_context().cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").cstate.update(_state,token_response ) + client_2.get_context().cstate.update(_state,token_response ) # Back to Client 1 - _state_dump = client_2.client_get("service_context").dump() - client_1.client_get("service_context").load(_state_dump) + _state_dump = client_2.get_context().dump() + client_1.get_context().load(_state_dump) req_args = {} - msg = client_1.client_get("service", "refresh_token").construct( + msg = client_1.get_service("refresh_token").construct( request_args=req_args, state=_state ) assert isinstance(msg, RefreshAccessTokenRequest) @@ -121,7 +121,7 @@ def test_construct_refresh_token_request(self): def test_do_userinfo_request_init(self): # Client 1 starts client_1 = RP(config=CONF) - _state = client_1.client_get("service_context").cstate.create_state(iss=ISSUER) + _state = client_1.get_context().cstate.create_state(iss=ISSUER) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" @@ -129,20 +129,20 @@ def test_do_userinfo_request_init(self): # Client 2 carries on client_2 = RP(config=CONF) - _state_dump = client_1.client_get("service_context").dump() - client_2.client_get("service_context").load(_state_dump) + _state_dump = client_1.get_context().dump() + client_2.get_context().load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.client_get("service_context").cstate.update(_state,auth_response) + client_2.get_context().cstate.update(_state,auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.client_get("service_context").cstate.update(_state,token_response) + client_2.get_context().cstate.update(_state,token_response) # Back to Client 1 - _state_dump = client_2.client_get("service_context").dump() - client_1.client_get("service_context").load(_state_dump) + _state_dump = client_2.get_context().dump() + client_1.get_context().load(_state_dump) - _srv = client_1.client_get("service", "userinfo") + _srv = client_1.get_service("userinfo") _srv.endpoint = "https://example.com/userinfo" _info = _srv.get_request_parameters(state=_state) assert _info diff --git a/tests/test_client_40_dpop.py b/tests/test_client_40_dpop.py index 80a9a964..6d6849b4 100644 --- a/tests/test_client_40_dpop.py +++ b/tests/test_client_40_dpop.py @@ -40,14 +40,14 @@ def create_client(self): self.client = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) - self.client.client_get("service_context").provider_info = { + self.client.get_context().provider_info = { "authorization_endpoint": "https://example.com/auth", "token_endpoint": "https://example.com/token", "dpop_signing_alg_values_supported": ["RS256", "ES256"], } def test_add_header(self): - token_serv = self.client.client_get("service", "accesstoken") + token_serv = self.client.get_service("accesstoken") req_args = { "grant_type": "authorization_code", "code": "SplxlOBeZQQYbYS6WxSbIA", @@ -99,7 +99,7 @@ def create_client(self): } self.client = Client(keyjar=CLI_KEY, config=config, services=services) - self.client.client_get("service_context").provider_info = { + self.client.get_context().provider_info = { "authorization_endpoint": "https://example.com/auth", "token_endpoint": "https://example.com/token", "dpop_signing_alg_values_supported": ["RS256", "ES256"], @@ -107,7 +107,7 @@ def create_client(self): } def test_add_header_token(self): - token_serv = self.client.client_get("service", "accesstoken") + token_serv = self.client.get_service("accesstoken") req_args = { "grant_type": "authorization_code", "code": "SplxlOBeZQQYbYS6WxSbIA", @@ -130,7 +130,7 @@ def test_add_header_token(self): assert _header["jwk"]["crv"] == "P-256" def test_add_header_userinfo(self): - userinfo_serv = self.client.client_get("service", "userinfo") + userinfo_serv = self.client.get_service("userinfo") req_args = {} access_token = "access.token.sign" headers = userinfo_serv.get_headers( diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 2558938c..fef02546 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -201,16 +201,16 @@ def test_do_provider_info(self): client_2 = rph_2.init_client("github") - _context_dump = client_1.client_get("service_context").dump() - client_2.client_get("service_context").load(_context_dump) - _service_dump = client_1.client_get("services").dump() - client_2.client_get("services").load( - _service_dump, init_args={"client_get": client_2.client_get} + _context_dump = client_1.get_context().dump() + client_2.get_context().load(_context_dump) + _service_dump = client_1.get_services().dump() + client_2.get_services().load( + _service_dump, init_args={"upstream_get": client_2.upstream_get} ) for service_type in ["authorization", "accesstoken", "userinfo"]: - _srv = client_2.client_get("service", service_type) - _endp = client_2.client_get("service_context").provider_info[_srv.endpoint_name] + _srv = client_2.get_service(service_type) + _endp = client_2.get_context().provider_info[_srv.endpoint_name] assert _srv.endpoint == _endp def test_do_client_registration(self): @@ -225,7 +225,7 @@ def test_do_client_registration(self): # only 2 things should have happened assert rph_1.hash2issuer["github"] == issuer - assert not client.client_get("service_context").get_usage("post_logout_redirect_uris") + assert not client.get_context().get_usage("post_logout_redirect_uris") def test_do_client_setup(self): rph_1 = RPHandler( @@ -234,7 +234,7 @@ def test_do_client_setup(self): client = rph_1.client_setup("github") _github_id = iss_id("github") - _context = client.client_get("service_context") + _context = client.get_context() assert _context.get_client_id() == "eeeeeeeee" assert _context.get_usage("client_secret") == "aaaaaaaaaaaaaaaaaaaa" @@ -247,8 +247,8 @@ def test_do_client_setup(self): assert len(keys) == 2 for service_type in ["authorization", "accesstoken", "userinfo"]: - _srv = client.client_get("service", service_type) - _endp = client.client_get("service_context").get("provider_info")[_srv.endpoint_name] + _srv = client.get_service(service_type) + _endp = client.get_context().get("provider_info")[_srv.endpoint_name] assert _srv.endpoint == _endp assert rph_1.hash2issuer["github"] == _context.get("issuer") @@ -264,7 +264,7 @@ def test_begin(self): client = rph_1.issuer2rp[_github_id] - assert client.client_get("service_context").get("issuer") == _github_id + assert client.get_context().get("issuer") == _github_id part = urlsplit(res["url"]) assert part.scheme == "https" @@ -309,7 +309,7 @@ def test_get_client_from_session_key(self): # redo rph_1.do_provider_info(state=res["state"]) # get new redirect_uris - cli2.client_get("service_context").set_usage("redirect_uris", []) + cli2.get_context().set_usage("redirect_uris", []) rph_1.do_client_registration(state=res["state"]) def test_finalize_auth(self): @@ -325,8 +325,8 @@ def test_finalize_auth(self): resp = rph_1.finalize_auth(client, _session["iss"], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} aresp = ( - client.client_get("service", "authorization") - .client_get("service_context").cstate.get(res["state"]) + client.get_service("authorization") + .upstream("service_context").cstate.get(res["state"]) ) assert set(aresp.keys()) == { "state", "code", 'iss', 'client_id', @@ -359,7 +359,7 @@ def test_get_tokens(self): client = rph_1.issuer2rp[_session["iss"]] _github_id = iss_id("github") - _context = client.client_get("service_context") + _context = client.get_context() _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] @@ -389,7 +389,7 @@ def test_get_tokens(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url auth_response = AuthorizationResponse(code="access_code", state=res["state"]) resp = rph_1.finalize_auth(client, _session["iss"], auth_response.to_dict()) @@ -405,8 +405,8 @@ def test_get_tokens(self): } atresp = ( - client.client_get("service", "accesstoken") - .client_get("service_context") + client.get_service("accesstoken") + .upstream_get("service_context") .cstate.get(res["state"]) ) assert set(atresp.keys()) == { @@ -434,7 +434,7 @@ def test_access_and_id_token(self): res = rph_1.begin(issuer_id="github") _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() @@ -467,7 +467,7 @@ def test_access_and_id_token(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) auth_response = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) @@ -483,7 +483,7 @@ def test_access_and_id_token_by_reference(self): res = rph_1.begin(issuer_id="github") _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() @@ -516,7 +516,7 @@ def test_access_and_id_token_by_reference(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) _ = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) @@ -532,7 +532,7 @@ def test_get_user_info(self): res = rph_1.begin(issuer_id="github") _session = rph_1.get_session_information(res["state"]) client = rph_1.issuer2rp[_session["iss"]] - _context = client.client_get("service_context") + _context = client.get_context() _nonce = _session["nonce"] _iss = _session["iss"] _aud = _context.get_client_id() @@ -565,7 +565,7 @@ def test_get_user_info(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "accesstoken").endpoint = _url + client.get_service("accesstoken").endpoint = _url _response = AuthorizationResponse(code="access_code", state=res["state"]) auth_response = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) @@ -581,7 +581,7 @@ def test_get_user_info(self): adding_headers={"Content-Type": "application/json"}, status=200, ) - client.client_get("service", "userinfo").endpoint = _url + client.get_service("userinfo").endpoint = _url userinfo_resp = rph_1.get_user_info(res["state"], client, token_resp["access_token"]) assert userinfo_resp diff --git a/tests/test_client_50_ciba.py b/tests/test_client_50_ciba.py index 283808c5..61d11220 100644 --- a/tests/test_client_50_ciba.py +++ b/tests/test_client_50_ciba.py @@ -39,7 +39,7 @@ def create_client(self): self.client = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) - self.client.client_get("service_context").provider_info = { + self.client.upstream_get("context").provider_info = { "authorization_endpoint": "https://example.com/auth", "token_endpoint": "https://example.com/token", "dpop_signing_alg_values_supported": ["RS256", "ES256"], diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index d64e906a..54d1bc99 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -33,17 +33,17 @@ def create_request(self): KEYS = init_key_jar(key_defs=KEYSPEC) entity = Entity(config=client_config, services=DEFAULT_OIDC_SERVICES, keyjar=KEYS) - entity.client_get("service_context").issuer = "https://server.otherop.com" - self.service = entity.client_get("service", "userinfo") + entity.get_context().issuer = "https://server.otherop.com" + self.service = entity.get_service("userinfo") - entity.client_get("service_context").work_environment.use = { + entity.get_context().work_environment.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", } def test_unpack_aggregated_response(self): - _cstate = self.service.client_get("service_context").cstate + _cstate = self.service.upstream_get("context").cstate # Add history auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", @@ -72,7 +72,7 @@ def test_unpack_aggregated_response(self): }, } - _jwt = JWT(key_jar=self.service.client_get("service_context").keyjar) + _jwt = JWT(key_jar=self.service.upstream_get("context").keyjar) _jws = _jwt.pack(payload=_distributed_respone) resp = { diff --git a/tests/test_server_00a_client_configure.py b/tests/test_server_00a_client_configure.py index 665d936f..e618e1e9 100644 --- a/tests/test_server_00a_client_configure.py +++ b/tests/test_server_00a_client_configure.py @@ -108,7 +108,7 @@ def test_verify_oidc_client_information_complext(): client_conf["client1"].update(EXTRA) - res = verify_oidc_client_information(client_conf, server_get=server.server_get) + res = verify_oidc_client_information(client_conf, upstream_get=server.upstream_get) assert res for cli, _cli_conf in res.items(): print(_cli_conf.extra()) @@ -135,5 +135,5 @@ def test_verify_oidc_client_information_2(): } } - res = verify_oidc_client_information(client_conf, server_get=server.server_get) + res = verify_oidc_client_information(client_conf, upstream_get=server.upstream_get) assert res diff --git a/tests/test_server_01_claims.py b/tests/test_server_01_claims.py index 582b3975..41d8d0bf 100644 --- a/tests/test_server_01_claims.py +++ b/tests/test_server_01_claims.py @@ -128,7 +128,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): self.server = Server(conf) - # self.endpoint_context = EndpointContext(conf=conf, server_get=self.server_get) + # self.endpoint_context = EndpointContext(conf=conf, upstream_get=self.upstream_get) self.endpoint_context = self.server.endpoint_context self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligtochintekort", @@ -141,7 +141,8 @@ def create_idtoken(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.endpoint_context.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) + self.server.get_attribute('keyjar').add_symmetric("client_1", "hemligtochintekort", + ["sig", "enc"]) self.claims_interface = self.endpoint_context.claims_interface self.user_id = USER_ID diff --git a/tests/test_server_03_authz_handling.py b/tests/test_server_03_authz_handling.py index 5dfc7d81..f50e2c90 100644 --- a/tests/test_server_03_authz_handling.py +++ b/tests/test_server_03_authz_handling.py @@ -134,10 +134,10 @@ def create_idtoken(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - server.endpoint_context.keyjar.add_symmetric( + server.get_attribute('attribute', 'keyjar').add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) - server.endpoint = do_endpoints(conf, server.server_get) + server.endpoint = do_endpoints(conf, server.upstream_get) self.session_manager = server.endpoint_context.session_manager self.user_id = USER_ID self.server = server @@ -193,7 +193,7 @@ def test_usage_rules_client(self): assert _usage_rules["refresh_token"] == {} def test_factory(self): - _mod = factory("Implicit", server_get=self.server.server_get) + _mod = factory("Implicit", upstream_get=self.server.upstream_get) assert isinstance(_mod, Implicit) def test_call(self): diff --git a/tests/test_server_06_grant.py b/tests/test_server_06_grant.py index c7002b3b..287c8b42 100644 --- a/tests/test_server_06_grant.py +++ b/tests/test_server_06_grant.py @@ -110,7 +110,7 @@ class TestGrant: @pytest.fixture(autouse=True) def create_session_manager(self): self.server = Server(conf=conf) - self.endpoint_context = self.server.server_get("endpoint_context") + self.endpoint_context = self.server.get_context() def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -134,14 +134,14 @@ def test_mint_token(self): code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -158,14 +158,14 @@ def test_grant(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -173,7 +173,7 @@ def test_grant(self): refresh_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -190,14 +190,14 @@ def test_get_token(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -219,14 +219,14 @@ def test_grant_revoked_based_on(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -234,7 +234,7 @@ def test_grant_revoked_based_on(self): refresh_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -256,14 +256,14 @@ def test_revoke(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -276,7 +276,7 @@ def test_revoke(self): access_token_2 = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -295,14 +295,14 @@ def test_json_conversion(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -333,7 +333,7 @@ def test_json_no_token_map(self): with pytest.raises(ValueError): grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) @@ -350,14 +350,14 @@ def test_json_custom_token_map(self): grant.token_map = token_map code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -365,7 +365,7 @@ def test_json_custom_token_map(self): grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="my_token", token_handler=DefaultToken("my_token", typ="M"), ) @@ -404,14 +404,14 @@ def test_get_spec(self): code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -452,7 +452,7 @@ def test_assigned_scope(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) @@ -461,7 +461,7 @@ def test_assigned_scope(self): access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -477,7 +477,7 @@ def test_assigned_scope_2nd(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) @@ -486,7 +486,7 @@ def test_assigned_scope_2nd(self): refresh_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -494,7 +494,7 @@ def test_assigned_scope_2nd(self): access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=refresh_token, @@ -506,7 +506,7 @@ def test_assigned_scope_2nd(self): access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=refresh_token, @@ -522,14 +522,14 @@ def test_grant_remove_based_on_code(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -537,7 +537,7 @@ def test_grant_remove_based_on_code(self): refresh_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -554,14 +554,14 @@ def test_grant_remove_one_by_one(self): grant = session_info["grant"] code = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -569,7 +569,7 @@ def test_grant_remove_one_by_one(self): refresh_token = grant.mint_token( session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index b1dac732..5bfd019f 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -196,7 +196,7 @@ def _mint_code(self, grant, session_id): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -205,7 +205,7 @@ def _mint_code(self, grant, session_id): def _mint_access_token(self, grant, session_id, token_ref): access_token = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -216,7 +216,7 @@ def _mint_access_token(self, grant, session_id, token_ref): def _mint_id_token(self, grant, session_id, token_ref=None, code=None, access_token=None): return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="id_token", token_handler=self.session_manager.token_handler["id_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now diff --git a/tests/test_server_09_authn_context.py b/tests/test_server_09_authn_context.py index cedd170c..b4b3e94a 100644 --- a/tests/test_server_09_authn_context.py +++ b/tests/test_server_09_authn_context.py @@ -172,12 +172,12 @@ def create_authn_broker(self): def test_pick_authn_one(self): request = {"acr_values": INTERNETPROTOCOLPASSWORD} - res = pick_auth(self.server.server_get("endpoint_context"), request) + res = pick_auth(self.server.get_context(), request) assert res["acr"] == INTERNETPROTOCOLPASSWORD def test_pick_authn_all(self): request = {"acr_values": INTERNETPROTOCOLPASSWORD} - res = pick_auth(self.server.server_get("endpoint_context"), request, pick_all=True) + res = pick_auth(self.server.get_context(), request, pick_all=True) assert len(res) == 2 diff --git a/tests/test_server_10_session_manager.py b/tests/test_server_10_session_manager.py index 381ee0bc..5dca75a7 100644 --- a/tests/test_server_10_session_manager.py +++ b/tests/test_server_10_session_manager.py @@ -220,7 +220,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now diff --git a/tests/test_server_12_session_life.py b/tests/test_server_12_session_life.py index a4ffa8d7..9a86d74b 100644 --- a/tests/test_server_12_session_life.py +++ b/tests/test_server_12_session_life.py @@ -97,7 +97,7 @@ def auth(self): code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -145,7 +145,7 @@ def test_code_flow(self): grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -156,7 +156,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=self.session_manager.token_handler["refresh_token"], based_on=tok, @@ -182,7 +182,7 @@ def test_code_flow(self): access_token_2 = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -330,7 +330,7 @@ def auth(self): # Constructing an authorization code is now done by code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -378,7 +378,7 @@ def test_code_flow(self): grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -390,7 +390,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="refresh_token", token_handler=self.session_manager.token_handler["refresh_token"], based_on=tok, @@ -424,7 +424,7 @@ def test_code_flow(self): access_token_2 = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now diff --git a/tests/test_server_13_user_authn.py b/tests/test_server_13_user_authn.py index 80e8ec13..b9279dce 100644 --- a/tests/test_server_13_user_authn.py +++ b/tests/test_server_13_user_authn.py @@ -140,13 +140,13 @@ def test_userpassjinja2(self): "kwargs": {"filename": full_path("passwd.json")}, } template_handler = self.endpoint_context.template_handler - res = UserPassJinja2(db, template_handler, server_get=self.server.server_get) + res = UserPassJinja2(db, template_handler, upstream_get=self.server.server_get) res() assert "page_header" in res.kwargs def test_basic_auth(self): basic_auth = base64.b64encode(b"diana:krall").decode() - ba = BasicAuthn(pwd={"diana": "krall"}, server_get=self.server.server_get) + ba = BasicAuthn(pwd={"diana": "krall"}, upstream_get=self.server.server_get) _info, _time_stamp = ba.authenticated_as(client_id="", authorization=f"Basic {basic_auth}") assert _info @@ -154,6 +154,6 @@ def test_no_auth(self): basic_auth = base64.b64encode( b"D\xfd\x8a\x85\xa6\xd1\x16\xe4\\6\x1e\x9ds~\xc3\t\x95\x99\x83\x91\x1f\xfb:iviviviv" ) - ba = SymKeyAuthn(symkey=b"0" * 32, ttl=600, server_get=self.server.server_get) + ba = SymKeyAuthn(symkey=b"0" * 32, ttl=600, upstream_get=self.server.server_get) _info, _time_stamp = ba.authenticated_as(client_id="", authorization=basic_auth) assert _info diff --git a/tests/test_server_16_endpoint.py b/tests/test_server_16_endpoint.py index 74002c0f..c00da5ae 100755 --- a/tests/test_server_16_endpoint.py +++ b/tests/test_server_16_endpoint.py @@ -31,12 +31,12 @@ } -def pre(args, request, endpoint_context): +def pre(args, request, context): args.update({"name": "{}, {}".format(args["family_name"], args["given_name"])}) return args -def post(cis, request, endpoint_context): +def post(cis, request, context): cis["request"] = request return cis @@ -108,7 +108,7 @@ def test_parse_dict(self): def test_parse_jwt(self): self.endpoint.request_format = "jwt" - kj = self.endpoint_context.keyjar + kj = self.endpoint.upstream_get('attribute','keyjar') request = REQ.to_jwt(kj.get_signing_key("RSA"), "RS256") req = self.endpoint.parse_request(request) assert req == REQ diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index e22a46c7..af80c675 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -48,10 +48,10 @@ class Endpoint_2(Endpoint): class Endpoint_3(Endpoint): name = "endpoint_3" - def __init__(self, server_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): + def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): Endpoint.__init__( self, - server_get, + upstream_get, add_claims_by_scope=add_claims_by_scope, **kwargs, ) @@ -252,7 +252,7 @@ def test_private_key_jwt_reusage_other_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True _assertion = _jwt.pack( - {"aud": [self.server.server_get("endpoint", "endpoint_1").full_path]} + {"aud": [self.server.get_endpoint("endpoint_1").full_path]} ) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} @@ -260,19 +260,19 @@ def test_private_key_jwt_reusage_other_endpoint(self): # This should be OK assert self.method.is_usable(request=request) self.method.verify( - request=request, endpoint=self.server.server_get("endpoint", "endpoint_1") + request=request, endpoint=self.server.get_endpoint("endpoint_1") ) # This should NOT be OK with pytest.raises(InvalidToken): self.method.verify( - request=request, endpoint=self.server.server_get("endpoint", "authorization") + request=request, endpoint=self.server.get_endpoint("authorization") ) # This should NOT be OK because this is the second time the token appears with pytest.raises(InvalidToken): self.method.verify( - request=request, endpoint=self.server.server_get("endpoint", "endpoint_1") + request=request, endpoint=self.server.get_endpoint("endpoint_1") ) def test_private_key_jwt_auth_endpoint(self): @@ -287,7 +287,7 @@ def test_private_key_jwt_auth_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True _assertion = _jwt.pack( - {"aud": [self.server.server_get("endpoint", "endpoint_2").full_path]} + {"aud": [self.server.get_endpoint("endpoint_2").full_path]} ) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} @@ -295,7 +295,7 @@ def test_private_key_jwt_auth_endpoint(self): assert self.method.is_usable(request=request) authn_info = self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "endpoint_2"), + endpoint=self.server.get_endpoint("endpoint_2"), ) assert authn_info["client_id"] == client_id @@ -400,7 +400,7 @@ def test_jws_authn_method_aud_token_endpoint(self): assert self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), key_type="client_secret", ) @@ -437,7 +437,7 @@ def test_jws_authn_method_aud_userinfo_endpoint(self): assert self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "endpoint_3"), + endpoint=self.server.get_endpoint("endpoint_3"), key_type="client_secret", ) @@ -475,7 +475,7 @@ def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.server_get) - self.endpoint_context = self.server.server_get("context") + self.endpoint_context = self.server.get_context() def test_verify_per_client(self): self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] @@ -484,7 +484,7 @@ def test_verify_per_client(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_4"), + endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"method": "public", "client_id": client_id} @@ -500,22 +500,23 @@ def test_verify_per_client_per_endpoint(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_4"), + endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"method": "public", "client_id": client_id} - res = verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), - ) - assert res == {} + with pytest.raises(ClientAuthenticationError) as e: + verify_client( + self.endpoint_context, + request, + endpoint=self.server.get_endpoint("endpoint_1"), + ) + assert e.value.args[0] == "Failed to verify client" request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -525,7 +526,7 @@ def test_verify_client_client_secret_post(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -547,7 +548,7 @@ def test_verify_client_jws_authn_method(self): self.endpoint_context, request, http_info=http_info, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert res["method"] == "client_secret_jwt" assert res["client_id"] == "client_id" @@ -559,7 +560,7 @@ def test_verify_client_bearer_body(self): self.endpoint_context, request, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "endpoint_3"), + endpoint=self.server.get_endpoint("endpoint_3"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" @@ -567,7 +568,7 @@ def test_verify_client_bearer_body(self): # def test_verify_client_client_secret_post(self): # request = {"client_id": client_id, "client_secret": client_secret} # res = verify_client( - # self.endpoint_context, request, endpoint=self.server.server_get("endpoint", + # self.endpoint_context, request, endpoint=self.server.upstream_get("endpoint", # "endpoint_1"), # ) # assert set(res.keys()) == {"method", "client_id"} @@ -583,7 +584,7 @@ def test_verify_client_client_secret_basic(self): self.endpoint_context, request={}, http_info=http_info, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_basic" @@ -600,7 +601,7 @@ def test_verify_client_bearer_header(self): request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "endpoint_2"), + endpoint=self.server.get_endpoint("endpoint_2"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_header" @@ -612,7 +613,7 @@ def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.server_get) - self.endpoint_context = self.server.server_get("context") + self.endpoint_context = self.server.get_context() def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() @@ -630,7 +631,7 @@ def test_verify_client_jws_authn_method(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert res["method"] == "client_secret_jwt" assert res["client_id"] == "client_id" @@ -642,7 +643,7 @@ def test_verify_client_bearer_body(self): self.endpoint_context, request, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "endpoint_3"), + endpoint=self.server.get_endpoint("endpoint_3"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" @@ -652,7 +653,7 @@ def test_verify_client_client_secret_post(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -667,7 +668,7 @@ def test_verify_client_client_secret_basic(self): self.endpoint_context, {}, http_info=http_info, - endpoint=self.server.server_get("endpoint", "endpoint_1"), + endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_basic" @@ -684,7 +685,7 @@ def test_verify_client_bearer_header(self): request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "endpoint_2"), + endpoint=self.server.get_endpoint("endpoint_2"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_header" @@ -695,7 +696,7 @@ def test_verify_client_authorization_none(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_2"), + endpoint=self.server.get_endpoint("endpoint_2"), ) assert res["method"] == "none" assert res["client_id"] == "client_id" @@ -706,7 +707,7 @@ def test_verify_client_registration_public(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_4"), + endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"client_id": "client_id", "method": "public"} @@ -716,7 +717,7 @@ def test_verify_client_registration_none(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "endpoint_4"), + endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"client_id": None, "method": "none"} @@ -737,7 +738,7 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, request, endpoint=server.server_get("endpoint", "endpoint_4") + server.endpoint_context, request, endpoint=server.get_endpoint("endpoint_4") ) assert res == {"client_id": "client_id", "method": "custom"} diff --git a/tests/test_server_20b_claims.py b/tests/test_server_20b_claims.py index d39ac8cd..3b2759ec 100644 --- a/tests/test_server_20b_claims.py +++ b/tests/test_server_20b_claims.py @@ -156,7 +156,7 @@ def test_get_claims(self, usage): assert claims == {} def test_get_claims_userinfo_3(self): - _module = self.server.server_get("endpoint", "userinfo") + _module = self.server.get_endpoint("userinfo") session_id = self._create_session(AREQ) _module.kwargs = { "base_claims": {"email": None, "email_verified": None}, @@ -178,7 +178,7 @@ def test_get_claims_userinfo_3(self): } def test_get_claims_introspection_3(self): - _module = self.server.server_get("endpoint", "introspection") + _module = self.server.get_endpoint("introspection") _module.kwargs = { "base_claims": {"email": None, "email_verified": None}, "enable_claims_per_client": True, @@ -206,8 +206,8 @@ def test_get_claims_all_usage(self): self.session_manager.token_handler["id_token"].kwargs = {} self.session_manager.token_handler["access_token"].kwargs = {} - self.server.server_get("endpoint", "userinfo").kwargs = {} - self.server.server_get("endpoint", "introspection").kwargs = {} + self.server.get_endpoint("userinfo").kwargs = {} + self.server.get_endpoint("introspection").kwargs = {} session_id = self._create_session(AREQ) claims = self.claims_interface.get_claims_all_usage(session_id, ["openid", "address"]) @@ -226,7 +226,7 @@ def test_get_claims_all_usage_2(self): "base_claims": {"email": None, "email_verified": None} } - self.server.server_get("endpoint", "userinfo").kwargs = { + self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ @@ -234,7 +234,7 @@ def test_get_claims_all_usage_2(self): "email", ] - self.server.server_get("endpoint", "introspection").kwargs = {"add_claims_by_scope": True} + self.server.get_endpoint("introspection").kwargs = {"add_claims_by_scope": True} self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} @@ -258,7 +258,7 @@ def test_get_user_claims(self): "base_claims": {"email": None, "email_verified": None} } - self.server.server_get("endpoint", "userinfo").kwargs = { + self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ @@ -266,7 +266,7 @@ def test_get_user_claims(self): "email", ] - self.server.server_get("endpoint", "introspection").kwargs = {"add_claims_by_scope": True} + self.server.get_endpoint("introspection").kwargs = {"add_claims_by_scope": True} self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index badd4842..cda6f5fc 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -212,24 +212,24 @@ def test_private_key_jwt_reusage_other_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [self.server.server_get("endpoint", "token").full_path]}) + _assertion = _jwt.pack({"aud": [self.server.get_endpoint("token").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} # This should be OK assert self.method.is_usable(request=request) - self.method.verify(request=request, endpoint=self.server.server_get("endpoint", "token")) + self.method.verify(request=request, endpoint=self.server.get_endpoint("token")) # This should NOT be OK with pytest.raises(InvalidToken): self.method.verify( - request=request, endpoint=self.server.server_get("endpoint", "authorization") + request=request, endpoint=self.server.get_endpoint("authorization") ) # This should NOT be OK because this is the second time the token appears with pytest.raises(InvalidToken): self.method.verify( - request=request, endpoint=self.server.server_get("endpoint", "token") + request=request, endpoint=self.server.get_endpoint("token") ) def test_private_key_jwt_auth_endpoint(self): @@ -244,7 +244,7 @@ def test_private_key_jwt_auth_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True _assertion = _jwt.pack( - {"aud": [self.server.server_get("endpoint", "authorization").full_path]} + {"aud": [self.server.get_endpoint("authorization").full_path]} ) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} @@ -252,7 +252,7 @@ def test_private_key_jwt_auth_endpoint(self): assert self.method.is_usable(request=request) authn_info = self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "authorization"), + endpoint=self.server.get_endpoint("authorization"), ) assert authn_info["client_id"] == client_id @@ -354,7 +354,7 @@ def test_jws_authn_method_aud_token_endpoint(self): assert self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), key_type="client_secret", ) @@ -391,7 +391,7 @@ def test_jws_authn_method_aud_userinfo_endpoint(self): assert self.method.verify( request=request, - endpoint=self.server.server_get("endpoint", "userinfo"), + endpoint=self.server.get_endpoint("userinfo"), key_type="client_secret", ) @@ -428,7 +428,7 @@ class TestVerify: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.server_get("context") + self.endpoint_context = self.server.get_context() def test_verify_per_client(self): self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] @@ -437,7 +437,7 @@ def test_verify_per_client(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "registration"), + endpoint=self.server.get_endpoint("registration"), ) assert res == {"method": "public", "client_id": client_id} @@ -453,14 +453,14 @@ def test_verify_per_client_per_endpoint(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "registration"), + endpoint=self.server.get_endpoint("registration"), ) assert res == {"method": "public", "client_id": client_id} res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert res == {} @@ -468,7 +468,7 @@ def test_verify_per_client_per_endpoint(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -478,7 +478,7 @@ def test_verify_client_client_secret_post(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -500,7 +500,7 @@ def test_verify_client_jws_authn_method(self): self.endpoint_context, request, http_info=http_info, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert res["method"] == "client_secret_jwt" assert res["client_id"] == "client_id" @@ -512,7 +512,7 @@ def test_verify_client_bearer_body(self): self.endpoint_context, request, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "userinfo"), + endpoint=self.server.get_endpoint("userinfo"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" @@ -522,7 +522,7 @@ def test_verify_client_client_secret_post(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -537,7 +537,7 @@ def test_verify_client_client_secret_basic(self): self.endpoint_context, request={}, http_info=http_info, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_basic" @@ -554,7 +554,7 @@ def test_verify_client_bearer_header(self): request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "authorization"), + endpoint=self.server.get_endpoint("authorization"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_header" @@ -565,7 +565,7 @@ class TestVerify2: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.server_get("context") + self.endpoint_context = self.server.get_context() def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() @@ -583,7 +583,7 @@ def test_verify_client_jws_authn_method(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert res["method"] == "client_secret_jwt" assert res["client_id"] == "client_id" @@ -595,7 +595,7 @@ def test_verify_client_bearer_body(self): self.endpoint_context, request, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "userinfo"), + endpoint=self.server.get_endpoint("userinfo"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" @@ -605,7 +605,7 @@ def test_verify_client_client_secret_post(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -620,7 +620,7 @@ def test_verify_client_client_secret_basic(self): self.endpoint_context, {}, http_info=http_info, - endpoint=self.server.server_get("endpoint", "token"), + endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_basic" @@ -637,7 +637,7 @@ def test_verify_client_bearer_header(self): request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, - endpoint=self.server.server_get("endpoint", "authorization"), + endpoint=self.server.get_endpoint("authorization"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_header" @@ -648,7 +648,7 @@ def test_verify_client_authorization_none(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "authorization"), + endpoint=self.server.get_endpoint("authorization"), ) assert res["method"] == "none" assert res["client_id"] == "client_id" @@ -659,7 +659,7 @@ def test_verify_client_registration_public(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "registration"), + endpoint=self.server.get_endpoint("registration"), ) assert res == {"client_id": "client_id", "method": "public"} @@ -669,7 +669,7 @@ def test_verify_client_registration_none(self): res = verify_client( self.endpoint_context, request, - endpoint=self.server.server_get("endpoint", "registration"), + endpoint=self.server.get_endpoint("registration"), ) assert res == {"client_id": None, "method": "none"} @@ -689,7 +689,7 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, request, endpoint=server.server_get("endpoint", "registration") + server.endpoint_context, request, endpoint=server.get_endpoint("registration") ) assert res == {"client_id": "client_id", "method": "custom"} diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index b363bac2..da99488a 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -211,7 +211,7 @@ def create_endpoint(self): } self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" - self.endpoint = server.server_get("endpoint", "session") + self.endpoint = server.get_endpoint("session") def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -229,7 +229,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -415,7 +415,7 @@ def create_endpoint(self): } self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" - self.endpoint = server.server_get("endpoint", "session") + self.endpoint = server.get_endpoint("session") def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -433,7 +433,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now diff --git a/tests/test_server_20f_userinfo.py b/tests/test_server_20f_userinfo.py index 70272ef9..57007a95 100644 --- a/tests/test_server_20f_userinfo.py +++ b/tests/test_server_20f_userinfo.py @@ -290,7 +290,7 @@ def test_collect_user_info_scope_not_supported_no_base_claims(self): session_id = self._create_session(_req) _uid, _cid, _gid = self.session_manager.decrypt_session_id(session_id) - _userinfo_endpoint = self.server.server_get("endpoint", "userinfo") + _userinfo_endpoint = self.server.get_endpoint("userinfo") _userinfo_endpoint.kwargs["add_claims_by_scope"] = False _userinfo_endpoint.kwargs["enable_claims_per_client"] = False del _userinfo_endpoint.kwargs["base_claims"] @@ -311,7 +311,7 @@ def test_collect_user_info_enable_claims_per_client(self): session_id = self._create_session(_req) _uid, _cid, _gid = self.session_manager.decrypt_session_id(session_id) - _userinfo_endpoint = self.server.server_get("endpoint", "userinfo") + _userinfo_endpoint = self.server.get_endpoint("userinfo") _userinfo_endpoint.kwargs["add_claims_by_scope"] = False _userinfo_endpoint.kwargs["enable_claims_per_client"] = True del _userinfo_endpoint.kwargs["base_claims"] diff --git a/tests/test_server_21_oidc_discovery_endpoint.py b/tests/test_server_21_oidc_discovery_endpoint.py index cde6da83..9e93f1cf 100755 --- a/tests/test_server_21_oidc_discovery_endpoint.py +++ b/tests/test_server_21_oidc_discovery_endpoint.py @@ -51,7 +51,7 @@ def create_endpoint(self): }, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint = server.server_get("endpoint", "discovery") + self.endpoint = server.get_endpoint("discovery") def test_do_response(self): args = self.endpoint.process_request({"resource": "acct:foo@example.com"}) diff --git a/tests/test_server_22_oidc_provider_config_endpoint.py b/tests/test_server_22_oidc_provider_config_endpoint.py index 3cbe05e2..b532d4aa 100755 --- a/tests/test_server_22_oidc_provider_config_endpoint.py +++ b/tests/test_server_22_oidc_provider_config_endpoint.py @@ -81,7 +81,7 @@ def create_endpoint(self, conf): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint_context = server.endpoint_context - self.endpoint = server.server_get("endpoint", "provider_config") + self.endpoint = server.get_endpoint("provider_config") def test_do_response(self): args = self.endpoint.process_request() @@ -99,7 +99,7 @@ def test_scopes_supported(self, conf): conf["scopes_supported"] = scopes_supported server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint = server.server_get("endpoint", "provider_config") + endpoint = server.get_endpoint("provider_config") args = endpoint.process_request() msg = endpoint.do_response(args["response_args"]) assert isinstance(msg, dict) diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 64cb2a1b..9e6efcf0 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -164,7 +164,7 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) server.endpoint_context.cdb["client_id"] = {} - self.endpoint = server.server_get("endpoint", "registration") + self.endpoint = server.get_endpoint("registration") def test_parse(self): _req = self.endpoint.parse_request(CLI_REQ.to_json()) diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index ecd0fdf9..fd12ca3c 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -266,13 +266,13 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint_context = endpoint_context - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.server_get("context").keyjar.add_symmetric( + self.endpoint.upstream_get("context").keyjar.add_symmetric( "client_1", "hemligtkodord1234567890" ) @@ -334,24 +334,24 @@ def test_do_response_code_token(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri(self.endpoint.server_get("context"), request, "redirect_uri") + verify_uri(self.endpoint.upstream_get("context"), request, "redirect_uri") def test_verify_uri_fragment(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} request = {"redirect_uri": "https://rp.example.com/cb#foobar"} with pytest.raises(URIError): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_noregistered(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(KeyError): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_unregistered(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -360,7 +360,7 @@ def test_verify_uri_unregistered(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_match(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] } @@ -370,7 +370,7 @@ def test_verify_uri_qp_match(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_mismatch(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] } @@ -392,7 +392,7 @@ def test_verify_uri_qp_mismatch(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } @@ -402,7 +402,7 @@ def test_verify_uri_qp_missing(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing_val(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] } @@ -412,7 +412,7 @@ def test_verify_uri_qp_missing_val(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_no_registered_qp(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -420,7 +420,7 @@ def test_verify_uri_no_registered_qp(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_verify_uri_wrong_uri_type(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -428,7 +428,7 @@ def test_verify_uri_wrong_uri_type(self): verify_uri(_context, request, "post_logout_redirect_uri", "client_id") def test_verify_uri_none_registered(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "post_logout_redirect_uri": [("https://rp.example.com/plrc", {})] } @@ -438,7 +438,7 @@ def test_verify_uri_none_registered(self): verify_uri(_context, request, "redirect_uri", "client_id") def test_get_uri(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = { @@ -449,7 +449,7 @@ def test_get_uri(self): assert get_uri(_context, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_redirect_uri(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -457,7 +457,7 @@ def test_get_uri_no_redirect_uri(self): assert get_uri(_context, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_registered(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -466,7 +466,7 @@ def test_get_uri_no_registered(self): get_uri(_context, request, "post_logout_redirect_uri") def test_get_uri_more_then_one_registered(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = { "redirect_uris": [ ("https://rp.example.com/cb", {}), @@ -489,7 +489,7 @@ def test_create_authn_response(self): scope="openid", ) - self.endpoint.server_get("context").cdb["client_id"] = { + self.endpoint.upstream_get("context").cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", @@ -517,7 +517,7 @@ def test_setup_auth(self): "id_token_signed_response_alg": "RS256", } - kaka = self.endpoint.server_get("context").cookie_handler.make_cookie_content( + kaka = self.endpoint.upstream_get("context").cookie_handler.make_cookie_content( "value", "sso" ) @@ -545,7 +545,7 @@ def test_setup_auth_error(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.server_get("context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -575,7 +575,7 @@ def test_setup_auth_invalid_scope(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _context.cdb["client_id"] = cinfo kaka = _context.cookie_handler.make_cookie_content("value", "sso") @@ -608,7 +608,7 @@ def test_setup_auth_user(self): session_id = self._create_session(request) - item = self.endpoint.server_get("context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -633,7 +633,7 @@ def test_setup_auth_session_revoked(self): session_id = self._create_session(request) - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager _csi = _mngr[session_id] _csi.revoked = True diff --git a/tests/test_server_24_oauth2_authorization_endpoint_jar.py b/tests/test_server_24_oauth2_authorization_endpoint_jar.py index 885e8a0e..ee275e42 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_server_24_oauth2_authorization_endpoint_jar.py @@ -193,7 +193,7 @@ def create_endpoint(self): endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager self.user_id = "diana" @@ -205,7 +205,7 @@ def test_parse_request_parameter(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("context").provider_info["issuer"], + aud=self.endpoint.upstream_get("context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -223,7 +223,7 @@ def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("context").provider_info["issuer"], + aud=self.endpoint.upstream_get("context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 68b63345..94939131 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -188,7 +188,7 @@ def create_endpoint(self, conf): } endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = endpoint_context.session_manager - self.token_endpoint = server.server_get("endpoint", "token") + self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" self.endpoint_context = endpoint_context @@ -215,7 +215,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -235,7 +235,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None): _token = grant.mint_token( _session_info, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token @@ -692,7 +692,7 @@ def test_do_refresh_access_token_not_allowed(self): grant = self.endpoint_context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.server_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("endpoint_context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -717,7 +717,7 @@ def test_do_refresh_access_token_revoked(self): grant = self.endpoint_context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.server_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("endpoint_context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index 3e16ec65..aa2dfb47 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -298,7 +298,7 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint_context = endpoint_context - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager self.user_id = "diana" @@ -434,7 +434,7 @@ def test_id_token_claims(self): _resp = self.endpoint.process_request(_pr_resp) idt = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.server_get("context").keyjar, + keyjar=self.endpoint.upstream_get("context").keyjar, ) assert idt # from config @@ -445,7 +445,7 @@ def test_id_token_claims(self): def test_re_authenticate(self): request = {"prompt": "login"} - authn = UserAuthnMethod(self.endpoint.server_get("context")) + authn = UserAuthnMethod(self.endpoint.upstream_get("context")) assert re_authenticate(request, authn) def test_id_token_acr(self): @@ -459,7 +459,7 @@ def test_id_token_acr(self): _resp = self.endpoint.process_request(_pr_resp) res = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.server_get("context").keyjar, + keyjar=self.endpoint.upstream_get("context").keyjar, ) assert res res = _resp["response_args"][verified_claim_name("id_token")] @@ -468,24 +468,24 @@ def test_id_token_acr(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri(self.endpoint.server_get("context"), request, "redirect_uri") + verify_uri(self.endpoint.upstream_get("context"), request, "redirect_uri") def test_verify_uri_fragment(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} request = {"redirect_uri": "https://rp.example.com/cb#foobar"} with pytest.raises(URIError): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_noregistered(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(KeyError): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_unregistered(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -494,7 +494,7 @@ def test_verify_uri_unregistered(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_match(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} @@ -502,7 +502,7 @@ def test_verify_uri_qp_match(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_mismatch(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -522,7 +522,7 @@ def test_verify_uri_qp_mismatch(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } @@ -532,7 +532,7 @@ def test_verify_uri_qp_missing(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_qp_missing_val(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = { "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] } @@ -542,7 +542,7 @@ def test_verify_uri_qp_missing_val(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_verify_uri_no_registered_qp(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} @@ -550,7 +550,7 @@ def test_verify_uri_no_registered_qp(self): verify_uri(_ec, request, "redirect_uri", "client_id") def test_get_uri(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = { @@ -561,7 +561,7 @@ def test_get_uri(self): assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_redirect_uri(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -569,7 +569,7 @@ def test_get_uri_no_redirect_uri(self): assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" def test_get_uri_no_registered(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -578,7 +578,7 @@ def test_get_uri_no_registered(self): get_uri(_ec, request, "post_logout_redirect_uri") def test_get_uri_more_then_one_registered(self): - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = { "redirect_uris": [ ("https://rp.example.com/cb", {}), @@ -601,7 +601,7 @@ def test_create_authn_response_id_token(self): scope=["openid", "profile"], ) - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], @@ -629,7 +629,7 @@ def test_create_authn_response_id_token_request_claims(self): scope=["openid"], ) - _ec = self.endpoint.server_get("context") + _ec = self.endpoint.upstream_get("context") _ec.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], @@ -663,7 +663,7 @@ def test_setup_auth(self): } session_id = self._create_session(request) - kaka = self.endpoint.server_get("endpoint_context").cookie_handler.make_cookie_content( + kaka = self.endpoint.upstream_get("endpoint_context").cookie_handler.make_cookie_content( value=json.dumps({"sid": session_id, "state": request.get("state")}), name=self.endpoint_context.cookie_handler.name["session"], ) @@ -692,7 +692,7 @@ def test_setup_auth_error(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -715,7 +715,7 @@ def test_setup_auth_user_form_post(self): nonce="nonce", scope="openid", ) - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.upstream_get("endpoint_context") session_id = self._create_session(request) @@ -743,7 +743,7 @@ def test_setup_auth_error_form_post(self): scope=["openid"], ) - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.process_request(request) @@ -767,7 +767,7 @@ def test_setup_auth_session_revoked(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.server_get("endpoint_context") + _ec = self.endpoint.upstream_get("endpoint_context") session_id = self._create_session(request) @@ -781,7 +781,7 @@ def test_setup_auth_session_revoked(self): assert set(res.keys()) == {"args", "function"} def test_check_session_iframe(self): - self.endpoint.server_get("endpoint_context").provider_info[ + self.endpoint.upstream_get("endpoint_context").provider_info[ "check_session_iframe" ] = "https://example.com/csi" _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) @@ -805,7 +805,7 @@ def test_setup_auth_login_hint(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -829,13 +829,13 @@ def test_setup_auth_login_hint2acrs(self): "kwargs": {"user": "knoll"}, "class": NoAuthn, } - self.endpoint.server_get("endpoint_context").authn_broker["foo"] = init_method( + self.endpoint.upstream_get("endpoint_context").authn_broker["foo"] = init_method( method_spec, None ) - item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication - item = self.endpoint.server_get("endpoint_context").authn_broker.db["foo"] + item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["foo"] item["method"].fail = NoSuchAuthentication res = self.endpoint.pick_authn_method(request, redirect_uri) @@ -851,7 +851,7 @@ def test_parse_request(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -877,7 +877,7 @@ def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" @@ -957,10 +957,10 @@ def test_do_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( orig_request.to_dict(), - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], ) - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("endpoint_context") endpoint_context.cdb["client_1"]["request_uris"] = [("https://example.com/request", {})] with responses.RequestsMock() as rsps: @@ -996,7 +996,7 @@ def test_do_request_uri(self): self.endpoint._do_request_uri(request, "client_1", endpoint_context) def test_post_parse_request(self): - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("endpoint_context") msg = self.endpoint._post_parse_request({}, "client_1", endpoint_context) assert "error" in msg @@ -1080,7 +1080,7 @@ def test_do_request_user(self): request["login_hint"] = "mail:diana@example.org" assert self.endpoint.do_request_user(request) == {} - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("endpoint_context") # userinfo _userinfo = init_user_info( { @@ -1246,7 +1246,7 @@ def create_endpoint(self): endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager self.user_id = "diana" @@ -1266,7 +1266,7 @@ def test_setup_acr_claim(self): ) redirect_uri = request["redirect_uri"] - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.upstream_get("endpoint_context") cinfo = _context.cdb["client_1"] res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) diff --git a/tests/test_server_26_oidc_userinfo_endpoint.py b/tests/test_server_26_oidc_userinfo_endpoint.py index 16b50117..76477c6b 100755 --- a/tests/test_server_26_oidc_userinfo_endpoint.py +++ b/tests/test_server_26_oidc_userinfo_endpoint.py @@ -215,7 +215,7 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] } - self.endpoint = self.server.server_get("endpoint", "userinfo") + self.endpoint = self.server.get_endpoint("userinfo") self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" @@ -235,7 +235,7 @@ def _mint_code(self, grant, session_id): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint.server_get("endpoint_context"), + context=self.endpoint.upstream_get("context"), token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -245,7 +245,7 @@ def _mint_token(self, token_class, grant, session_id, token_ref=None): _session_info = self.session_manager.get_session_info(session_id, grant=True) return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint.server_get("endpoint_context"), + context=self.endpoint.upstream_get("context"), token_class=token_class, token_handler=self.session_manager.token_handler[token_class], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -255,7 +255,7 @@ def _mint_token(self, token_class, grant, session_id, token_ref=None): def test_init(self): assert self.endpoint assert set( - self.endpoint.server_get("endpoint_context").provider_info["claims_supported"] + self.endpoint.upstream_get("context").provider_info["claims_supported"] ) == { "address", "birthdate", @@ -342,7 +342,7 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint.server_get("endpoint_context").cdb["client_1"][ + self.endpoint.upstream_get("context").cdb["client_1"][ "userinfo_signed_response_alg" ] = "ES256" @@ -369,9 +369,9 @@ def test_scopes_to_claims(self): access_token = self._mint_token("access_token", grant, session_id) self.endpoint.kwargs["add_claims_by_scope"] = True - self.endpoint.server_get("endpoint_context").claims_interface.add_claims_by_scope = True + self.endpoint.upstream_get("context").claims_interface.add_claims_by_scope = True grant.claims = { - "userinfo": self.endpoint.server_get("endpoint_context").claims_interface.get_claims( + "userinfo": self.endpoint.upstream_get("context").claims_interface.get_claims( session_id=session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" ) } @@ -416,9 +416,9 @@ def test_scopes_to_claims_per_client(self): access_token = self._mint_token("access_token", grant, session_id) self.endpoint.kwargs["add_claims_by_scope"] = True - self.endpoint.server_get("endpoint_context").claims_interface.add_claims_by_scope = True + self.endpoint.upstream_get("context").claims_interface.add_claims_by_scope = True grant.claims = { - "userinfo": self.endpoint.server_get("endpoint_context").claims_interface.get_claims( + "userinfo": self.endpoint.upstream_get("context").claims_interface.get_claims( session_id=session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" ) } @@ -438,6 +438,8 @@ def test_scopes_to_claims_per_client(self): } def test_allowed_scopes(self): + _context = self.endpoint.upstream_get("context") + _context.scopes_handler.allowed_scopes = list(SCOPE2CLAIMS.keys()) _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship"] @@ -446,9 +448,9 @@ def test_allowed_scopes(self): access_token = self._mint_token("access_token", grant, session_id) self.endpoint.kwargs["add_claims_by_scope"] = True - self.endpoint.server_get("endpoint_context").claims_interface.add_claims_by_scope = True + _context.claims_interface.add_claims_by_scope = True grant.claims = { - "userinfo": self.endpoint.server_get("endpoint_context").claims_interface.get_claims( + "userinfo": _context.claims_interface.get_claims( session_id=session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" ) } @@ -467,6 +469,42 @@ def test_allowed_scopes(self): "sub" } + def test_allowed_scopes_per_client(self): + self.endpoint_context.cdb["client_1"]["scopes_to_claims"] = { + **SCOPE2CLAIMS, + "research_and_scholarship_2": [ + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + ], + } + self.endpoint_context.cdb["client_1"]["allowed_scopes"] = list(SCOPE2CLAIMS.keys()) + + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "research_and_scholarship_2"] + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + access_token = self._mint_token("access_token", grant, session_id) + + self.endpoint.kwargs["add_claims_by_scope"] = True + self.endpoint.upstream_get("context").claims_interface.add_claims_by_scope = True + grant.claims = { + "userinfo": self.endpoint.upstream_get("context").claims_interface.get_claims( + session_id=session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" + ) + } + + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} + _req = self.endpoint.parse_request({}, http_info=http_info) + args = self.endpoint.process_request(_req, http_info=http_info) + + assert set(args["response_args"].keys()) == {"sub"} + def test_wrong_type_of_token(self): _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship"] @@ -591,7 +629,7 @@ def test_userinfo_claims_post(self): def test_process_request_absent_userinfo_conf(self): # consider to have a configuration without userinfo defined in - ec = self.endpoint.server_get("endpoint_context") + ec = self.endpoint.upstream_get("context") ec.userinfo = None session_id = self._create_session(AUTH_REQ) diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index eea70c34..7d9cc772 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -223,9 +223,9 @@ def create_endpoint(self): } self.endpoint_context = endpoint_context self.session_manager = endpoint_context.session_manager - self.authn_endpoint = server.server_get("endpoint", "authorization") - self.session_endpoint = server.server_get("endpoint", "session") - self.token_endpoint = server.server_get("endpoint", "token") + self.authn_endpoint = server.get_endpoint("authorization") + self.session_endpoint = server.get_endpoint("session") + self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" def test_end_session_endpoint(self): @@ -236,7 +236,7 @@ def test_end_session_endpoint(self): _ = self.session_endpoint.process_request("", http_info=http_info) def _create_cookie(self, session_id): - ec = self.session_endpoint.server_get("context") + ec = self.session_endpoint.upstream_get("context") return ec.new_cookie( name=ec.cookie_handler.name["session"], sid=session_id, @@ -276,7 +276,7 @@ def _auth_with_id_token(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - _info = self.session_endpoint.server_get("context").cookie_handler.parse_cookie( + _info = self.session_endpoint.upstream_get("context").cookie_handler.parse_cookie( "oidc_op", _resp["cookie"] ) # value is a JSON document @@ -287,7 +287,7 @@ def _auth_with_id_token(self, state): def _mint_token(self, token_class, grant, session_id, token_ref=None): return grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class=token_class, token_handler=self.session_manager.token_handler[token_class], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -336,7 +336,7 @@ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): http_info = {"cookie": [cookie]} msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.server_get("context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("context").keyjar) msg2 = Message(id_token_hint=id_token) msg2[verified_claim_name("id_token_hint")] = msg[verified_claim_name("id_token")] @@ -376,7 +376,7 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): http_info = {"cookie": [cookie]} post_logout_redirect_uri = join_query( - *self.session_endpoint.server_get("context").cdb["client_1"][ + *self.session_endpoint.upstream_get("context").cdb["client_1"][ "post_logout_redirect_uri" ] ) @@ -403,7 +403,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): post_logout_redirect_uri = "https://demo.example.com/log_out" msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.server_get("context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("context").keyjar) with pytest.raises(RedirectURIError): self.session_endpoint.process_request( @@ -420,14 +420,14 @@ def test_back_channel_logout_no_backchannel_logout_uri(self): info = self._code_auth("1234567") res = self.session_endpoint.do_back_channel_logout( - self.session_endpoint.server_get("context").cdb["client_1"], info["session_id"] + self.session_endpoint.upstream_get("context").cdb["client_1"], info["session_id"] ) assert res is None def test_back_channel_logout(self): info = self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.upstream_get("context").cdb["client_1"]) _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_id"] = "client_1" res = self.session_endpoint.do_back_channel_logout(_cdb, info["session_id"]) @@ -442,7 +442,7 @@ def test_back_channel_logout(self): def test_front_channel_logout(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.upstream_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["client_id"] = "client_1" res = do_front_channel_logout_iframe(_cdb, ISS, "_sid_") @@ -451,7 +451,7 @@ def test_front_channel_logout(self): def test_front_channel_logout_session_required(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.upstream_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -467,7 +467,7 @@ def test_front_channel_logout_session_required(self): def test_front_channel_logout_with_query(self): self._code_auth("1234567") - _cdb = copy.copy(self.session_endpoint.server_get("context").cdb["client_1"]) + _cdb = copy.copy(self.session_endpoint.upstream_get("context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout?entity_id=foo" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -489,10 +489,10 @@ def test_logout_from_client_bc(self): _code, client_session_info=True, handler_key="authorization_code" ) - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "client_id" ] = "client_1" @@ -517,12 +517,12 @@ def test_logout_from_client_fc(self): _code, client_session_info=True, handler_key="authorization_code" ) - # del self.session_endpoint.server_get("context").cdb['client_1'][ + # del self.session_endpoint.upstream_get("context").cdb['client_1'][ # 'backchannel_logout_uri'] - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "client_id" ] = "client_1" @@ -554,14 +554,18 @@ def test_logout_from_client(self): ) # client0 - self.session_endpoint.server_get("context").cdb["client_1"][ - "backchannel_logout_uri"] = "https://example.com/bc_logout" - self.session_endpoint.server_get("context").cdb["client_1"][ - "client_id"] = "client_1" - self.session_endpoint.server_get("context").cdb["client_2"][ - "frontchannel_logout_uri"] = "https://example.com/fc_logout" - self.session_endpoint.server_get("context").cdb["client_2"][ - "client_id"] = "client_2" + self.session_endpoint.upstream_get("context").cdb["client_1"][ + "backchannel_logout_uri" + ] = "https://example.com/bc_logout" + self.session_endpoint.upstream_get("context").cdb["client_1"][ + "client_id" + ] = "client_1" + self.session_endpoint.upstream_get("context").cdb["client_2"][ + "frontchannel_logout_uri" + ] = "https://example.com/fc_logout" + self.session_endpoint.upstream_get("context").cdb["client_2"][ + "client_id" + ] = "client_2" res = self.session_endpoint.logout_all_clients(_session_info["branch_id"]) @@ -599,7 +603,7 @@ def test_do_verified_logout(self): _session_info = self.session_manager.get_session_info_by_token( _code, handler_key="authorization_code" ) - _cdb = self.session_endpoint.server_get("context").cdb + _cdb = self.session_endpoint.upstream_get("context").cdb _cdb["client_1"]["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_1"]["client_id"] = "client_1" @@ -628,21 +632,21 @@ def test_logout_from_client_no_session(self): self._code_auth2("abcdefg") # client0 - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.server_get("context").cdb["client_1"][ + self.session_endpoint.upstream_get("context").cdb["client_1"][ "client_id" ] = "client_1" - self.session_endpoint.server_get("context").cdb["client_2"][ + self.session_endpoint.upstream_get("context").cdb["client_2"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.server_get("context").cdb["client_2"][ + self.session_endpoint.upstream_get("context").cdb["client_2"][ "client_id" ] = "client_2" _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_info["branch_id"]) - self.session_endpoint.server_get("context").session_manager.delete([_uid, _cid]) + self.session_endpoint.upstream_get("context").session_manager.delete([_uid, _cid]) with pytest.raises(InvalidBranchID): self.session_endpoint.logout_all_clients(_session_info["branch_id"]) diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index 773ef1d0..f14ec928 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -210,8 +210,8 @@ def create_endpoint(self, jwt_token): endpoint_context.keyjar.export_jwks_as_json(private=True), endpoint_context.issuer, ) - self.introspection_endpoint = server.server_get("endpoint", "introspection") - self.token_endpoint = server.server_get("endpoint", "token") + self.introspection_endpoint = server.get_endpoint("introspection") + self.token_endpoint = server.get_endpoint("token") self.session_manager = endpoint_context.session_manager self.user_id = "diana" @@ -231,7 +231,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - endpoint_context=self.token_endpoint.server_get("context"), + context=self.token_endpoint.upstream_get("context"), token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -242,7 +242,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): def _get_access_token(self, areq): session_id = self._create_session(areq) # Consent handling - grant = self.token_endpoint.server_get("context").authz(session_id, areq) + grant = self.token_endpoint.upstream_get("context").authz(session_id, areq) self.session_manager[session_id] = grant # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) @@ -256,7 +256,7 @@ def test_parse_no_authn(self): def test_parse_with_client_auth_in_req(self): access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -273,7 +273,7 @@ def test_parse_with_wrong_client_authn(self): _basic_token = "{}:{}".format( "client_1", - self.introspection_endpoint.server_get("context").cdb["client_1"][ + self.introspection_endpoint.upstream_get("context").cdb["client_1"][ "client_secret" ], ) @@ -293,7 +293,7 @@ def test_process_request(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get("context").cdb[ + "client_secret": self.introspection_endpoint.upstream_get("context").cdb[ "client_1" ]["client_secret"], } @@ -317,7 +317,7 @@ def test_do_response(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get("context").cdb[ + "client_secret": self.introspection_endpoint.upstream_get("context").cdb[ "client_1" ]["client_secret"], } @@ -348,7 +348,7 @@ def test_do_response(self): def test_do_response_no_token(self): # access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { "client_id": "client_1", @@ -360,7 +360,7 @@ def test_do_response_no_token(self): def test_access_token(self): access_token = self._get_access_token(AUTH_REQ) - _context = self.introspection_endpoint.server_get("context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -378,12 +378,12 @@ def test_code(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.server_get("context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.upstream_get("context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) - _context = self.introspection_endpoint.server_get("context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { @@ -399,7 +399,7 @@ def test_code(self): def test_introspection_claims(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) @@ -407,14 +407,14 @@ def test_introspection_claims(self): self.introspection_endpoint.kwargs["enable_claims_per_client"] = True - _c_interface = self.introspection_endpoint.server_get("endpoint_context").claims_interface + _c_interface = self.introspection_endpoint.upstream_get("endpoint_context").claims_interface grant.claims = { "introspection": _c_interface.get_claims( session_id, scopes=AUTH_REQ["scope"], claims_release_point="introspection" ) } - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -435,7 +435,7 @@ def test_jwt_unknown_key(self): _jwt = JWT( _keyjar, - iss=self.introspection_endpoint.server_get("endpoint_context").issuer, + iss=self.introspection_endpoint.upstream_get("endpoint_context").issuer, lifetime=3600, ) @@ -443,7 +443,7 @@ def test_jwt_unknown_key(self): _payload = {"sub": "subject_id"} _token = _jwt.pack(_payload, aud="client_1") - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { @@ -466,7 +466,7 @@ def mock(): monkeypatch.setattr("idpyoidc.server.token.utc_time_sans_frac", mock) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { @@ -482,7 +482,7 @@ def test_revoked_access_token(self): access_token = self._get_access_token(AUTH_REQ) access_token.revoked = True - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { @@ -496,12 +496,12 @@ def test_revoked_access_token(self): def test_introspect_id_token(self): session_id = self._create_session(AUTH_REQ) - grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) id_token = self._mint_token("id_token", grant, session_id, code) - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { "token": id_token.value, diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index 2e803ba7..af0c7324 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -125,8 +125,8 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.registration_endpoint = server.server_get("endpoint", "registration") - self.registration_api_endpoint = server.server_get("endpoint", "registration_read") + self.registration_endpoint = server.get_endpoint("registration") + self.registration_api_endpoint = server.get_endpoint("registration_read") server.endpoint_context.cdb["client_1"] = {} def test_do_response(self): diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index 7b9c1ea0..a6942938 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -242,8 +242,8 @@ class TestEndpoint(object): def create_endpoint(self, conf): server = create_server(conf) self.session_manager = server.endpoint_context.session_manager - self.authn_endpoint = server.server_get("endpoint", "authorization") - self.token_endpoint = server.server_get("endpoint", "token") + self.authn_endpoint = server.get_endpoint("authorization") + self.token_endpoint = server.get_endpoint("token") def test_unsupported_code_challenge_methods(self, conf): conf["add_on"]["pkce"]["kwargs"]["code_challenge_methods"] = ["dada"] @@ -306,8 +306,8 @@ def test_no_code_challenge(self): def test_not_essential(self, conf): conf["add_on"]["pkce"]["kwargs"]["essential"] = False server = create_server(conf) - authn_endpoint = server.server_get("endpoint", "authorization") - token_endpoint = server.server_get("endpoint", "token") + authn_endpoint = server.get_endpoint("authorization") + token_endpoint = server.get_endpoint("token") _authn_req = AUTH_REQ.copy() _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -324,10 +324,10 @@ def test_not_essential(self, conf): def test_essential_per_client(self, conf): conf["add_on"]["pkce"]["kwargs"]["essential"] = False server = create_server(conf) - authn_endpoint = server.server_get("endpoint", "authorization") - token_endpoint = server.server_get("endpoint", "token") + authn_endpoint = server.get_endpoint("authorization") + token_endpoint = server.get_endpoint("token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.server_get("context") + endpoint_context = server.get_context() endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = True _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -339,10 +339,10 @@ def test_essential_per_client(self, conf): def test_not_essential_per_client(self, conf): conf["add_on"]["pkce"]["kwargs"]["essential"] = True server = create_server(conf) - authn_endpoint = server.server_get("endpoint", "authorization") - token_endpoint = server.server_get("endpoint", "token") + authn_endpoint = server.get_endpoint("authorization") + token_endpoint = server.get_endpoint("token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.server_get("context") + endpoint_context = server.get_context() endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = False _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -372,7 +372,7 @@ def test_unknown_code_challenge_method(self): def test_unsupported_code_challenge_method(self, conf): conf["add_on"]["pkce"]["kwargs"]["code_challenge_methods"] = ["plain"] server = create_server(conf) - authn_endpoint = server.server_get("endpoint", "authorization") + authn_endpoint = server.get_endpoint("authorization") _cc_info = _code_challenge() _authn_req = AUTH_REQ.copy() @@ -438,9 +438,9 @@ def test_missing_authz_endpoint(): } configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - add_pkce_support(server.server_get("endpoints")) + add_pkce_support(server.get_endpoints()) - assert "pkce" not in server.server_get("context").args + assert "pkce" not in server.get_context().args def test_missing_token_endpoint(): @@ -463,6 +463,6 @@ def test_missing_token_endpoint(): } configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - add_pkce_support(server.server_get("endpoints")) + add_pkce_support(server.get_endpoints()) - assert "pkce" not in server.server_get("context").args + assert "pkce" not in server.get_context().args diff --git a/tests/test_server_34_oidc_sso.py b/tests/test_server_34_oidc_sso.py index 99b2db41..d85c5510 100755 --- a/tests/test_server_34_oidc_sso.py +++ b/tests/test_server_34_oidc_sso.py @@ -202,7 +202,7 @@ def create_endpoint_context(self): endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") self.endpoint_context = endpoint_context self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -211,7 +211,7 @@ def create_endpoint_context(self): def test_sso(self): request = self.endpoint.parse_request(AUTH_REQ_DICT) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.server_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) # info = self.endpoint.process_request(request) @@ -224,7 +224,7 @@ def test_sso(self): # second login - from 2nd client request = self.endpoint.parse_request(AUTH_REQ_2.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.server_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) sid2 = info["session_id"] @@ -237,7 +237,7 @@ def test_sso(self): # third login - from 3rd client request = self.endpoint.parse_request(AUTH_REQ_3.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.server_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) assert set(info.keys()) == {"session_id", "identity", "user"} @@ -250,7 +250,7 @@ def test_sso(self): request = self.endpoint.parse_request(AUTH_REQ_4.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.server_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] # Parse cookies once before setup_auth kakor = self.endpoint_context.cookie_handler.parse_cookie( @@ -267,12 +267,12 @@ def test_sso(self): # Fifth login - from 2nd client - wrong cookie request = self.endpoint.parse_request(AUTH_REQ_2.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.server_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=kakor) # No valid login cookie so new session assert info["session_id"] != sid2 - user_session_info = self.endpoint.server_get("endpoint_context").session_manager.get( + user_session_info = self.endpoint.upstream_get("endpoint_context").session_manager.get( ["diana"] ) assert len(user_session_info.subordinate) == 3 @@ -285,13 +285,13 @@ def test_sso(self): # Should be one grant for each of client_2 and client_3 and # 2 grants for client_1 - csi1 = self.endpoint.server_get("endpoint_context").session_manager.get( + csi1 = self.endpoint.upstream_get("endpoint_context").session_manager.get( ["diana", "client_1"] ) - csi2 = self.endpoint.server_get("endpoint_context").session_manager.get( + csi2 = self.endpoint.upstream_get("endpoint_context").session_manager.get( ["diana", "client_2"] ) - csi3 = self.endpoint.server_get("endpoint_context").session_manager.get( + csi3 = self.endpoint.upstream_get("endpoint_context").session_manager.get( ["diana", "client_3"] ) diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index 2faa76d6..3e8b4dd6 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -216,7 +216,7 @@ def create_endpoint(self, conf): endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") endpoint_context.userinfo = USERINFO self.session_manager = endpoint_context.session_manager - self.token_endpoint = server.server_get("endpoint", "token") + self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" self.endpoint_context = endpoint_context @@ -243,7 +243,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -263,7 +263,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None): _token = grant.mint_token( _session_info, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token @@ -896,7 +896,7 @@ def test_do_refresh_access_token_not_allowed(self): grant = self.endpoint_context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.server_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("endpoint_context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -920,7 +920,7 @@ def test_do_refresh_access_token_revoked(self): grant = self.endpoint_context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.server_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("endpoint_context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -1030,7 +1030,7 @@ def create_endpoint(self, conf): } endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = endpoint_context.session_manager - self.token_endpoint = server.server_get("endpoint", "token") + self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" self.endpoint_context = endpoint_context @@ -1054,7 +1054,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -1118,7 +1118,7 @@ def test_old_jwt_token(self): payload = _handler.load_custom_claims(payload) # payload.update(kwargs) - _context = _handler.server_get("endpoint_context") + _context = _handler.upstream_get("endpoint_context") signer = JWT( key_jar=_context.keyjar, iss=_handler.issuer, diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index a18342cf..1287b817 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -204,8 +204,8 @@ def create_endpoint(self): "allowed_scopes": ["openid", "profile", "offline_access"], } self.endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") - self.endpoint = server.server_get("endpoint", "token") - self.introspection_endpoint = server.server_get("endpoint", "introspection") + self.endpoint = server.get_endpoint("token") + self.introspection_endpoint = server.get_endpoint("introspection") self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" @@ -229,7 +229,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint.server_get("context"), + context=self.endpoint.upstream_get("context"), token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 8fc0dc34..4caa190f 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -177,8 +177,8 @@ def create_endpoint(self): self.rp_keyjar.export_jwks(issuer_id="s6BhdRkqt3"), "s6BhdRkqt3" ) - self.pushed_authorization_endpoint = server.server_get("endpoint", "pushed_authorization") - self.authorization_endpoint = server.server_get("endpoint", "authorization") + self.pushed_authorization_endpoint = server.get_endpoint("pushed_authorization") + self.authorization_endpoint = server.get_endpoint("authorization") def test_init(self): assert self.pushed_authorization_endpoint diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 46c521e7..358f8478 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -222,14 +222,14 @@ def create_endpoint(self): server2.endpoint_context.load( _store, init_args={ - "server_get": server2.server_get, + "upstream_get": server2.upstream_get, "handler": server2.endpoint_context.session_manager.token_handler, }, ) self.endpoint = { - 1: server1.server_get("endpoint", "userinfo"), - 2: server2.server_get("endpoint", "userinfo"), + 1: server1.get_endpoint("userinfo"), + 2: server2.get_endpoint("userinfo"), } self.session_manager = { @@ -254,7 +254,7 @@ def _mint_code(self, grant, session_id, index=1): # Constructing an authorization code is now done _code = grant.mint_token( session_id, - endpoint_context=self.endpoint[index].server_get("context"), + context=self.endpoint[index].upstream_get("context"), token_class="authorization_code", token_handler=self.session_manager[index].token_handler["authorization_code"], ) @@ -272,7 +272,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None, index=1): _token = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint[index].server_get("context"), + context=self.endpoint[index].upstream_get("context"), token_class="access_token", token_handler=self.session_manager[index].token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token @@ -285,7 +285,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None, index=1): def _dump_restore(self, fro, to): _store = self.session_manager[fro].dump() self.session_manager[to].load( - _store, init_args={"server_get": self.endpoint[to].server_get} + _store, init_args={"upstream_get": self.endpoint[to].upstream_get} ) def test_init(self): @@ -294,12 +294,12 @@ def test_init(self): self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] ) == {"openid"} assert set( - self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] - ) == set(self.endpoint[2].server_get("endpoint_context").provider_info["scopes_supported"]) + self.endpoint[1].upstream_get("endpoint_context").provider_info["claims_supported"] + ) == set(self.endpoint[2].upstream_get("endpoint_context").provider_info["claims_supported"]) def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) # grant, session_id = self._do_grant(AUTH_REQ, index=1) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) @@ -315,7 +315,7 @@ def test_parse(self): def test_process_request(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) @@ -328,7 +328,7 @@ def test_process_request(self): def test_process_request_not_allowed(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -362,7 +362,7 @@ def test_process_request_not_allowed(self): def test_do_response(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -380,12 +380,12 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint[2].server_get("endpoint_context").cdb["client_1"][ + self.endpoint[2].upstream_get("endpoint_context").cdb["client_1"][ "userinfo_signed_response_alg" ] = "ES256" session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -404,13 +404,13 @@ def test_custom_scope(self): _auth_req["scope"] = ["openid", "research_and_scholarship"] session_id = self._create_session(_auth_req, index=2) - grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, _auth_req) + grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, _auth_req) self._dump_restore(2, 1) grant.claims = { "userinfo": self.endpoint[1] - .server_get("endpoint_context") + .upstream_get("endpoint_context") .claims_interface.get_claims( session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" ) @@ -448,7 +448,7 @@ def test_sman_db_integrity(self): it show that flush and loads method will keep order, anyway. """ session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) sman = self.session_manager[1] session_dump = sman.dump() diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index 7b74e172..156162bc 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -196,7 +196,7 @@ def create_endpoint(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } self.user_id = "diana" - self.token_endpoint = server.server_get("endpoint", "token") + self.token_endpoint = server.get_endpoint("token") self.session_manager = self.endpoint_context.session_manager def _create_session(self, auth_req, sub_type="public", sector_identifier=""): @@ -219,7 +219,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, diff --git a/tests/test_server_61_add_on.py b/tests/test_server_61_add_on.py index b83a5ea3..b226af33 100644 --- a/tests/test_server_61_add_on.py +++ b/tests/test_server_61_add_on.py @@ -145,10 +145,10 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.endpoint = server.server_get("endpoint", "authorization") + self.endpoint = server.get_endpoint("authorization") def test_process_request(self): - _context = self.endpoint.server_get("context") + _context = self.endpoint.upstream_get("context") assert _context.add_on["extra_args"] == {"authorization": {"iss": "issuer"}} _pr_resp = self.endpoint.parse_request(AUTH_REQ) diff --git a/tests/test_y_actor_01.py b/tests/test_y_actor_01.py index 3ec9cbc2..e69de29b 100644 --- a/tests/test_y_actor_01.py +++ b/tests/test_y_actor_01.py @@ -1,351 +0,0 @@ - -import copy -import os - -import pytest -from cryptojwt.jwt import JWT -from cryptojwt.key_jar import KeyJar -from cryptojwt.key_jar import init_key_jar - -from idpyoidc.actor import CIBAClient -from idpyoidc.actor import CIBAServer -from idpyoidc.client.entity import Entity -from idpyoidc.message.oidc.backchannel_authentication import AuthenticationRequest -from idpyoidc.server import OPConfiguration -from idpyoidc.server import Server -from idpyoidc.server.authn_event import create_authn_event -from idpyoidc.server.client_authn import verify_client -from idpyoidc.server.oidc.backchannel_authentication import BackChannelAuthentication -from idpyoidc.server.oidc.backchannel_authentication import ClientNotification -from idpyoidc.server.oidc.token import Token -from idpyoidc.server.user_authn.authn_context import MOBILETWOFACTORCONTRACT -from idpyoidc.util import rndstr -from tests import CRYPT_CONFIG -from tests import SESSION_PARAMS - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) -ISSUER_1 = "https://example.com/actor1" -ISSUER_2 = "https://example.com/actor2" - -KEYSPEC = [ - {"type": "RSA", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], -] - -CAPABILITIES = { - "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], - "token_endpoint_auth_methods_supported": [ - "client_secret_post", - "client_secret_basic", - "client_secret_jwt", - "private_key_jwt", - ], - "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], - "claim_types_supported": ["normal", "aggregated", "distributed"], - "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, -} - -SERVER_CONFIG = { - "httpc_params": {"verify": False, "timeout": 1}, - "capabilities": CAPABILITIES, - "keys": {"uri_path": "jwks.json", "key_defs": KEYSPEC}, - "token_handler_args": { - "jwks_file": "private/token_jwks.json", - "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, - "token": { - "class": "idpyoidc.server.token.jwt_token.JWTToken", - "kwargs": { - "lifetime": 3600, - "base_claims": {"eduperson_scoped_affiliation": None}, - "add_claims_by_scope": True, - }, - }, - "refresh": { - "class": "idpyoidc.server.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600}, - }, - "id_token": { - "class": "idpyoidc.server.token.id_token.IDToken", - "kwargs": { - "base_claims": { - "email": {"essential": True}, - "email_verified": {"essential": True}, - } - }, - }, - }, - "endpoint": { - "token": {"path": "token", "class": Token, "kwargs": {}}, - }, - "client_authn": verify_client, - "session_params": SESSION_PARAMS, -} - - -def _create_client(issuer, client_id, service): - client_config = { - "issuer": issuer, - "client_id": client_id, - "client_secret": rndstr(24), - "redirect_uris": [f"https://example.com/{client_id}/authz_cb"], - "behaviour": {"response_types": ["code"]}, - "client_authn_methods": { - "client_notification_authn": "idpyoidc.client.oidc.backchannel_authentication.ClientNotificationAuthn" - }, - } - _services = { - "discovery": { - "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery" - }, - "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, - } - _services.update(service) - - _cli_1_key = init_key_jar(key_defs=KEYSPEC) - - return Entity(config=client_config, services=_services, keyjar=_cli_1_key) - - -def _create_server(issuer, endpoint, port, extra_conf=None): - _config = copy.deepcopy(SERVER_CONFIG) - _config["issuer"] = issuer - _config["endpoint"].update(endpoint) - if extra_conf: - _config.update(extra_conf) - - return Server(OPConfiguration(conf=_config, base_path=BASEDIR, domain="127.0.0.1", port=port)) - - -# Locally defined -def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) -> str: - _jwt = JWT(keyjar) - _info = _jwt.unpack(login_hint_token) - # here comes the special knowledge - _sub_id = _info.get("sub_id") - _sub = "" - if _sub_id: - if _sub_id["format"] == "phone": - _sub = "tel:" + _sub_id["phone"] - elif _sub_id["format"] == "mail": - _sub = "mail:" + _sub_id["mail"] - - if _sub and context and context.login_hint_lookup: - try: - _sub = context.login_hint_lookup(_sub) - except KeyError: - _sub = "" - - return _sub - - -class TestPushActor: - @pytest.fixture(autouse=True) - def create_actor(self): - # ============== ACTOR 1 ============== - # Actor 1 can use Authentication Service and provides a Client Notification Endpoint - actor_1 = CIBAClient() - actor_1.client = _create_client( - ISSUER_2, - "actor1", - { - "authentication": { - "class": "idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication" - } - }, - ) - - endpoint = { - "client_notify": { - "path": "notify", - "class": ClientNotification, - "kwargs": {"client_authn_method": ["client_notification_authn"]}, - } - } - extra = { - "client_authn_methods": { - "client_notification_authn": "idpyoidc.server.oidc.backchannel_authentication.ClientNotificationAuthn" - } - } - - actor_1.server = _create_server(ISSUER_1, endpoint, 6000, extra_conf=extra) - - self.actor_1 = actor_1 - - # ============== ACTOR 2 ============== - # Provides Authentication endpoint and can use the Client notification service - actor_2 = CIBAServer() - actor_2.client = _create_client( - ISSUER_1, - "actor2", - { - "notification": { - "class": "idpyoidc.client.oidc.backchannel_authentication.ClientNotification" - } - }, - ) - endpoint = { - "backchannel_authentication": { - "path": "authentication", - "class": BackChannelAuthentication, - "kwargs": { - "client_authn_method": [ - "client_secret_basic", - "client_secret_post", - "client_secret_jwt", - "private_key_jwt", - ], - "parse_login_hint_token": {"func": parse_login_hint_token}, - }, - } - } - extra = { - "login_hint_lookup": {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, - "userinfo": { - "class": "idpyoidc.server.user_info.UserInfo", - "kwargs": {"db_file": "users.json"}, - }, - } - actor_2.server = _create_server(ISSUER_2, endpoint, 7000, extra) - - # register clients with servers. - _server_context = actor_1.server.server_get("context") - _client_context = actor_2.client.client_get("service_context") - _server_context.cdb = { - _client_context.client_id: { - "client_secret": _client_context.client_secret, - }, - actor_2.server.server_get("context").issuer: { - "client_secret": _client_context.client_secret - }, - } - _server_context = actor_2.server.server_get("context") - _client_context = actor_1.client.client_get("service_context") - _server_context.cdb = { - _client_context.client_id: {"client_secret": _client_context.client_secret}, - actor_1.server.server_get("context").issuer: { - "client_secret": _client_context.client_secret - }, - } - - # Transfer provider metadata 1->2 and 2->1 - _client_context = actor_2.client.client_get("service_context") - _server_context = actor_1.server.server_get("context") - _client_context.provider_info = _server_context.provider_info - - _client_context = actor_1.client.client_get("service_context") - _server_context = actor_2.server.server_get("context") - _client_context.provider_info = _server_context.provider_info - - _server_context.parse_login_hint_token = parse_login_hint_token - - # keys - _client_keyjar = actor_2.client.client_get("service_context").keyjar - _server_keyjar = actor_1.server.server_get("context").keyjar - _server_keyjar.import_jwks(_client_keyjar.export_jwks(), "actor2") - _client_keyjar.import_jwks(_server_keyjar.export_jwks(), ISSUER_1) - - _client_keyjar = actor_1.client.client_get("service_context").keyjar - _server_keyjar = actor_2.server.server_get("context").keyjar - _server_keyjar.import_jwks(_client_keyjar.export_jwks(), "actor1") - _client_keyjar.import_jwks(_server_keyjar.export_jwks(), ISSUER_2) - - self.actor_1 = actor_1 - self.actor_2 = actor_2 - - def _create_session( - self, server, user_id, auth_req, sub_type="public", sector_identifier="", authn_info="" - ): - if sector_identifier: - authz_req = auth_req.copy() - authz_req["sector_identifier_uri"] = sector_identifier - else: - authz_req = auth_req - client_id = authz_req["client_id"] - ae = create_authn_event(user_id, authn_info=authn_info) - _session_manager = server.endpoint_context.session_manager - return _session_manager.create_session( - ae, authz_req, user_id, client_id=client_id, sub_type=sub_type - ) - - def test_init(self): - assert self.actor_1.client - assert self.actor_2.client - assert self.actor_1.server - assert self.actor_2.server - - def test_query(self): - _req = self.actor_1.create_authentication_request( - scope="openid email example-scope", - binding_message="W4SCT", - login_hint="mail:diana@example.org", - ) - assert _req - assert _req["method"] == "GET" - assert isinstance(_req["request"], AuthenticationRequest) - assert _req["request"]["login_hint"] == "mail:diana@example.org" - - # On the CIBA server side - _endpoint = self.actor_2.server.server_get("endpoint", "backchannel_authentication") - _request = _endpoint.parse_request(_req["request"].to_urlencoded()) - assert _request - # If ping mode - assert "client_notification_token" in _request - req_user = _endpoint.do_request_user(_request) - assert req_user == "diana" - # Construct response to the authentication request - _info = _endpoint.process_request(_request) - assert _info - - # User interaction with the authentication device returns some authentication info - - session_id_2 = self._create_session( - self.actor_2.server, req_user, _request, authn_info=MOBILETWOFACTORCONTRACT - ) - - # Create fake token response - token_request = { - "grant_type": "urn:openid:params:grant-type:ciba", - "auth_req_id": _info["response_args"]["auth_req_id"], - "client_id": "actor1", - } - _token_endpoint = self.actor_2.server.server_get("endpoint", "token") - _treq = _token_endpoint.parse_request(token_request) - # Construct response to the authentication request - _tinfo = _token_endpoint.process_request(_treq) - assert _tinfo - - # Send the response to the client notification endpoint - - _tinfo["response_args"]["client_notification_token"] = _request["client_notification_token"] - _notification_service = self.actor_2.client.client_get("service", "client_notification") - _not_req = _notification_service.get_request_parameters( - request_args=_tinfo["response_args"], authn_method="client_notification_authn" - ) - - assert _not_req - - # The receiver of the notification - - _ninfo = self.actor_1.do_client_notification( - _not_req["body"], http_info={"headers": _not_req["headers"]} - ) - assert _ninfo is None \ No newline at end of file diff --git a/tests/x_test_ciba_01_backchannel_auth.py b/tests/x_test_ciba_01_backchannel_auth.py new file mode 100644 index 00000000..8d8b9969 --- /dev/null +++ b/tests/x_test_ciba_01_backchannel_auth.py @@ -0,0 +1,617 @@ +import os + +import pytest +from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.key_jar import build_keyjar +from cryptojwt.key_jar import init_key_jar + +from idpyoidc.client.defaults import DEFAULT_OAUTH2_SERVICES +from idpyoidc.client.oauth2 import Client +from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message.oidc.backchannel_authentication import AuthenticationRequest +from idpyoidc.message.oidc.backchannel_authentication import NotificationRequest +from idpyoidc.message.oidc.backchannel_authentication import TokenRequest +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from idpyoidc.server import init_service +from idpyoidc.server import init_user_info +from idpyoidc.server import user_info +from idpyoidc.server.authn_event import create_authn_event +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.oidc.backchannel_authentication import BackChannelAuthentication +from idpyoidc.server.oidc.token import Token +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + +from . import CRYPT_CONFIG +from . import SESSION_PARAMS +from . import full_path + +QUERY_2 = ( + "request=eyJraWQiOiJsdGFjZXNidyIsImFsZyI6IkVTMjU2In0.eyJpc3MiOiJz" + "NkJoZFJrcXQzIiwiYXVkIjoiaHR0cHM6Ly9zZXJ2ZXIuZXhhbXBsZS5jb20iLCJl" + "eHAiOjE1Mzc4MjAwODYsImlhdCI6MTUzNzgxOTQ4NiwibmJmIjoxNTM3ODE4ODg2" + "LCJqdGkiOiI0TFRDcUFDQzJFU0M1QldDbk4zajU4RW5BIiwic2NvcGUiOiJvcGVu" + "aWQgZW1haWwgZXhhbXBsZS1zY29wZSIsImNsaWVudF9ub3RpZmljYXRpb25fdG9r" + "ZW4iOiI4ZDY3ZGM3OC03ZmFhLTRkNDEtYWFiZC02NzcwN2IzNzQyNTUiLCJiaW5k" + "aW5nX21lc3NhZ2UiOiJXNFNDVCIsImxvZ2luX2hpbnRfdG9rZW4iOiJleUpyYVdR" + "aU9pSnNkR0ZqWlhOaWR5SXNJbUZzWnlJNklrVlRNalUySW4wLmV5SnpkV0pmYVdR" + "aU9uc2labTl5YldGMElqb2ljR2h2Ym1VaUxDSndhRzl1WlNJNklpc3hNek13TWpn" + "eE9EQXdOQ0o5ZlEuR1NxeEpzRmJJeW9qZGZNQkR2M01PeUFwbENWaVZrd1FXenRo" + "Q1d1dTlfZ25LSXFFQ1ppbHdBTnQxSGZJaDN4M0pGamFFcS01TVpfQjNxZWIxMU5B" + "dmcifQ.ELJvZ2RfBl05bq7nx7pXhagzL9R75mUwO-yZScB1aT3mp480fCQ5KjRVD" + "womMMjiMKUI4sx8VrPgAZuTfsNSvA&" + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3A" + "client-assertion-type%3Ajwt-bearer&" + "client_assertion=eyJraWQiOiJsdGFjZXNidyIsImFsZyI6IkVTMjU2In0.eyJ" + "pc3MiOiJzNkJoZFJrcXQzIiwic3ViIjoiczZCaGRSa3F0MyIsImF1ZCI6Imh0dHB" + "zOi8vc2VydmVyLmV4YW1wbGUuY29tIiwianRpIjoiY2NfMVhzc3NmLTJpOG8yZ1B" + "6SUprMSIsImlhdCI6MTUzNzgxOTQ4NiwiZXhwIjoxNTM3ODE5Nzc3fQ.PWb_VMzU" + "IbD_aaO5xYpygnAlhRIjzoc6kxg4NixDuD1DVpkKVSBbBweqgbDLV-awkDtuWnyF" + "yUpHqg83AUV5TA" +) + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +ISSUER = "https://example.com/" + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + +CLIENT_ID = "client_id" +CLIENT_SECRET = "a_longer_client_secret" +CLI1 = "https://client1.example.com/" + + +# Locally defined +def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) -> str: + _jwt = JWT(keyjar) + _info = _jwt.unpack(login_hint_token) + # here comes the special knowledge + _sub_id = _info.get("sub_id") + _sub = "" + if _sub_id: + if _sub_id["format"] == "phone": + _sub = "tel:" + _sub_id["phone"] + elif _sub_id["format"] == "mail": + _sub = "mail:" + _sub_id["mail"] + + if _sub and context and context.login_hint_lookup: + try: + _sub = context.login_hint_lookup(_sub) + except KeyError: + _sub = "" + + return _sub + + +SERVER_CONF = { + "issuer": ISSUER, + "httpc_params": {"verify": False, "timeout": 1}, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "base_claims": {"eduperson_scoped_affiliation": None}, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "base_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + }, + }, + "endpoint": { + "bc_authentication": { + "path": "backchannel_authn", + "class": BackChannelAuthentication, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ], + "parse_login_hint_token": {"func": parse_login_hint_token}, + }, + }, + "token": {"path": "token", "class": Token, "kwargs": {}}, + }, + "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": "users.json"}, + }, + "session_params": SESSION_PARAMS, +} + + +class TestBCAEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) + self.endpoint_context = server.endpoint_context + self.endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + self.endpoint = server.get_endpoint("backchannel_authentication") + self.token_endpoint = server.get_endpoint("token") + + self.client_keyjar = build_keyjar(KEYDEFS) + # Add servers keys + self.client_keyjar.import_jwks(server.endpoint_context.keyjar.export_jwks(), ISSUER) + # The only own key the client has a this point + self.client_keyjar.add_symmetric("", CLIENT_SECRET, ["sig"]) + # Need to add the client_secret as a symmetric key bound to the client_id + server.endpoint_context.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) + server.endpoint_context.keyjar.import_jwks(self.client_keyjar.export_jwks(), CLIENT_ID) + + server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + # login_hint + server.endpoint_context.login_hint_lookup = init_service( + {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, None + ) + # userinfo + _userinfo = init_user_info( + { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + }, + "", + ) + server.endpoint_context.login_hint_lookup.userinfo = _userinfo + self.session_manager = server.endpoint_context.session_manager + + def test_login_hint_token(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="ES256") + _payload = {"sub_id": {"format": "phone", "phone": "+46907865000"}} + _login_hint_token = _jwt.pack(_payload, aud=[ISSUER]) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint_token": _login_hint_token, + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded(), verify_args={"mode": "ping"}) + assert req + req_user = self.endpoint.do_request_user(req) + assert req_user == "diana" + + def test_login_hint_token_jwt(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="ES256") + _payload = {"sub_id": {"format": "phone", "phone": "+46907865000"}} + _login_hint_token = _jwt.pack(_payload, aud=[ISSUER]) + + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="ES256") + _jwt.with_jti = True + _request_payload = { + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint_token": _login_hint_token, + } + _request_object = _jwt.pack(_request_payload, aud=[ISSUER]) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "request": _request_object, + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + assert req + req_user = self.endpoint.do_request_user(req) + assert req_user == "diana" + + def test_id_token_hint(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + # The old ID token + _idt_payload = { + "sub": "Anna", + "iss": ISSUER, + "aud": [CLIENT_ID], + "exp": utc_time_sans_frac() + 3600, + } + + _id_token_hint = _jwt.pack(_idt_payload) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "id_token_hint": _id_token_hint, + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + assert req + # If ping mode + assert "client_notification_token" in req + req_user = self.endpoint.do_request_user(req) + assert req_user == "Anna" + + def test_login_hint(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint": "mail:diana@example.org", + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + assert req + # If ping mode + assert "client_notification_token" in req + req_user = self.endpoint.do_request_user(req) + assert req_user == "diana" + + def test_login_hint_and_id_token_hint(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + # The old ID token + _idt_payload = { + "sub": "Anna", + "iss": ISSUER, + "aud": [CLIENT_ID], + "exp": utc_time_sans_frac() + 3600, + } + + _id_token_hint = _jwt.pack(_idt_payload) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint": "mail:diana@example.org", + "id_token_hint": _id_token_hint, + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + assert "error" in req + + def test_ping_and_no_client_notification_token(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "binding_message": "W4SCT", + "login_hint": "mail:diana@example.org", + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded(), verify_args={"mode": "ping"}) + assert "error" in req + + def test_request_and_extra_parameter(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="ES256") + _payload = {"sub_id": {"format": "phone", "phone": "+13302818004"}} + _login_hint_token = _jwt.pack(_payload, aud=[ISSUER]) + + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="ES256") + _jwt.with_jti = True + _request_payload = { + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint_token": _login_hint_token, + } + _request_object = _jwt.pack(_request_payload, aud=[ISSUER]) + + request = { + "scope": "openid email example-scope", + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "request": _request_object, + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + assert "error" in req + + def _create_session(self, user_id, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(user_id) + return self.session_manager.create_session( + ae, authz_req, user_id, client_id=client_id, sub_type=sub_type + ) + + def test_login_hint_response(self): + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint": "mail:diana@example.org", + } + + req = AuthenticationRequest(**request) + req = self.endpoint.parse_request(req.to_urlencoded()) + _info = self.endpoint.process_request(req) + assert _info + sid = self.session_manager.auth_req_id_map[_info["response_args"]["auth_req_id"]] + _user_id, _client_id, _grant_id = self.session_manager.decrypt_session_id(sid) + # Some time passes and the client authentication is successfully performed + session_id_2 = self._create_session(_user_id, req) + + # token request comes in + _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER + "token"]}) + + token_request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "auth_req_id": _info["response_args"]["auth_req_id"], + "grant_type": "urn:openid:params:grant-type:ciba", + } + _treq = TokenRequest(**token_request) + _req = self.token_endpoint.parse_request(_treq.to_urlencoded()) + assert _req + _info = self.token_endpoint.process_request(_req) + assert _info + assert set(_info["response_args"].keys()) == { + "token_type", + "scope", + "access_token", + "expires_in", + "id_token", + } + + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLI_KEY = init_key_jar( + public_path="{}/pub_client.jwks".format(_dirname), + private_path="{}/priv_client.jwks".format(_dirname), + key_defs=KEYSPEC, + issuer_id="client_id", +) + + +class TestBCAEndpointService(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + self.ciba = {"server": self._create_server(), "client": self._create_ciba_client()} + + def _create_server(self): + server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) + endpoint_context = server.endpoint_context + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + + client_keyjar = build_keyjar(KEYDEFS) + # Add servers keys + client_keyjar.import_jwks(server.endpoint_context.keyjar.export_jwks(), ISSUER) + # The only own key the client has a this point + client_keyjar.add_symmetric("", CLIENT_SECRET, ["sig"]) + # Need to add the client_secret as a symmetric key bound to the client_id + server.endpoint_context.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) + server.endpoint_context.keyjar.import_jwks(client_keyjar.export_jwks(), CLIENT_ID) + + server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + # login_hint + server.endpoint_context.login_hint_lookup = init_service( + {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, None + ) + # userinfo + _userinfo = init_user_info( + { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + }, + "", + ) + server.endpoint_context.login_hint_lookup.userinfo = _userinfo + return server + + def _create_ciba_client(self): + config = { + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "redirect_uris": ["https://example.com/cb"], + "services": { + "client_notification": { + "class": "idpyoidc.client.oidc.backchannel_authentication.ClientNotification", + "kwargs": {"conf": {"default_authn_method": "client_notification_authn"}}, + }, + }, + "client_authn_methods": { + "client_notification_authn": "idpyoidc.client.oidc.backchannel_authentication.ClientNotificationAuthn" + }, + } + + client = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + + client.upstream_get("context").provider_info = { + "client_notification_endpoint": "https://example.com/notify", + } + + return client + + def _create_session(self, user_id, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(user_id) + _session_manager = self.ciba["server"].endpoint_context.session_manager + return _session_manager.create_session( + ae, authz_req, user_id, client_id=client_id, sub_type=sub_type + ) + + def test_client_notification(self): + _keyjar = self.ciba["server"].endpoint_context.keyjar + _jwt = JWT(_keyjar, iss=CLIENT_ID, sign_alg="HS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [ISSUER]}) + + request = { + "client_assertion": _assertion, + "client_assertion_type": JWT_BEARER, + "scope": "openid email example-scope", + "client_notification_token": "8d67dc78-7faa-4d41-aabd-67707b374255", + "binding_message": "W4SCT", + "login_hint": "mail:diana@example.org", + } + + _authn_endpoint = self.ciba["server"].upstream_get("endpoint", "backchannel_authentication") + + req = AuthenticationRequest(**request) + req = _authn_endpoint.parse_request(req.to_urlencoded()) + _info = _authn_endpoint.process_request(req) + assert _info + + _session_manager = self.ciba["server"].endpoint_context.session_manager + sid = _session_manager.auth_req_id_map[_info["response_args"]["auth_req_id"]] + _user_id, _client_id, _grant_id = _session_manager.decrypt_session_id(sid) + + # Some time passes and the client authentication is successfully performed + # The interaction with the authentication device is not shown + session_id_2 = self._create_session(_user_id, req) + + # Now it's time to send a client notification + req_args = { + "auth_req_id": _info["response_args"]["auth_req_id"], + "client_notification_token": request["client_notification_token"], + } + + _service = self.ciba["client"].upstream_get("service", "client_notification") + _req_param = _service.get_request_parameters(request_args=req_args) + assert _req_param + assert isinstance(_req_param["request"], NotificationRequest) + assert set(_req_param.keys()) == {"method", "request", "url", "body", "headers"} + assert _req_param["method"] == "POST" + # This is the client's notification endpoint + assert ( + _req_param["url"] + == self.ciba["client"] + .upstream_get("context") + .provider_info["client_notification_endpoint"] + ) + assert set(_req_param["request"].keys()) == {"auth_req_id", "client_notification_token"} From b52c82073ac0aceaf54c52e06e6f70b28b3c7d40 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 20:04:17 +0100 Subject: [PATCH 43/76] Rebased onto improved --- src/idpyoidc/client/client_auth.py | 6 +- src/idpyoidc/client/entity.py | 56 +++- src/idpyoidc/client/http.py | 4 +- src/idpyoidc/client/oauth2/__init__.py | 4 +- .../oauth2/add_on/pushed_authorization.py | 5 +- src/idpyoidc/client/oidc/__init__.py | 6 +- src/idpyoidc/client/oidc/authorization.py | 9 +- src/idpyoidc/client/oidc/utils.py | 7 +- src/idpyoidc/node.py | 11 +- src/idpyoidc/server/__init__.py | 9 - src/idpyoidc/server/endpoint.py | 1 + src/idpyoidc/server/endpoint_context.py | 5 +- src/idpyoidc/server/oauth2/authorization.py | 2 +- src/idpyoidc/server/oidc/registration.py | 4 +- src/idpyoidc/server/oidc/session.py | 3 +- src/idpyoidc/server/token/jwt_token.py | 1 - tests/test_client_04_service.py | 3 + tests/test_client_06_client_authn.py | 2 +- tests/test_client_10_entity.py | 71 +++++ tests/test_client_12_client_auth.py | 9 +- tests/test_client_13_service_context.py | 254 ++++++++++++++++++ .../test_client_14_service_context_impexp.py | 4 +- tests/test_client_19_webfinger.py | 4 +- tests/test_client_21_oidc_service.py | 4 +- tests/test_client_24_oic_utils.py | 2 +- tests/test_client_28_rp_handler_oidc.py | 2 +- tests/test_client_30_rph_defaults.py | 2 +- tests/test_client_51_identity_assurance.py | 2 +- tests/test_server_03_authz_handling.py | 2 +- tests/test_server_08_id_token.py | 28 +- tests/test_server_09_authn_context.py | 4 +- tests/test_server_17_client_authn.py | 69 +++-- tests/test_server_20b_claims.py | 2 +- tests/test_server_20c_authz_handling.py | 2 +- tests/test_server_20d_client_authn.py | 60 +++-- tests/test_server_20e_jwt_token.py | 24 +- ...st_server_23_oidc_registration_endpoint.py | 15 +- ...server_24_oauth2_authorization_endpoint.py | 6 +- ...er_24_oauth2_authorization_endpoint_jar.py | 6 +- tests/test_server_24_oauth2_token_endpoint.py | 2 +- ...t_server_24_oidc_authorization_endpoint.py | 17 +- tests/test_server_30_oidc_end_session.py | 4 +- tests/test_server_31_oauth2_introspection.py | 5 +- tests/test_server_33_oauth2_pkce.py | 4 +- tests/test_server_34_oidc_sso.py | 6 +- tests/test_server_35_oidc_token_endpoint.py | 28 +- tests/test_server_36_oauth2_token_exchange.py | 2 +- ...t_server_40_oauth2_pushed_authorization.py | 6 +- tests/test_server_50_persistence.py | 25 +- tests/x_test_ciba_01_backchannel_auth.py | 92 +++---- 50 files changed, 646 insertions(+), 255 deletions(-) create mode 100644 tests/test_client_10_entity.py diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 760812f3..1ab0b100 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -452,7 +452,7 @@ def _get_signing_key(self, algorithm, keyjar, key_types, kid=None): return signing_key - def _get_audience_and_algorithm(self, context, **kwargs): + def _get_audience_and_algorithm(self, context, keyjar, **kwargs): algorithm = None # audience for the signed JWT depends on which endpoint @@ -468,7 +468,7 @@ def _get_audience_and_algorithm(self, context, **kwargs): else: for alg in algs: # pick the first one I support and have keys for if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar( - alg, context + alg, keyjar ): algorithm = alg break @@ -485,7 +485,7 @@ def _construct_client_assertion(self, service, **kwargs): _context = service.upstream_get("context") _entity = service.upstream_get("entity") _keyjar = service.upstream_get('attribute', 'keyjar') - audience, algorithm = self._get_audience_and_algorithm(_context, **kwargs) + audience, algorithm = self._get_audience_and_algorithm(_context, _keyjar, **kwargs) if "kid" in kwargs: signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"], diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 7f1739aa..5db28cc7 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -5,8 +5,8 @@ from cryptojwt import KeyJar from cryptojwt.key_jar import init_key_jar -from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.client_auth import client_auth_setup +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.configure import Configuration from idpyoidc.client.configure import get_configuration from idpyoidc.client.defaults import DEFAULT_OAUTH2_SERVICES @@ -74,6 +74,15 @@ def redirect_uris_from_callback_uris(callback_uris): class Entity(Unit): + parameter = { + 'entity_id': None, + 'jwks_uri': None, + 'httpc_params': None, + 'key_conf': None, + 'keyjar': KeyJar, + 'context': None + } + def __init__( self, keyjar: Optional[KeyJar] = None, @@ -93,9 +102,9 @@ def __init__( entity_id=entity_id) if context: - self._service_context = context + self.context = context else: - self._service_context = ServiceContext(config=config, jwks_uri=jwks_uri, + self.context = ServiceContext(config=config, jwks_uri=jwks_uri, upstream_get=self.unit_get) if services: @@ -119,10 +128,10 @@ def get_services(self, *arg): return self._service def get_service_context(self, *arg): # Want to get rid of this - return self._service_context + return self.context def get_context(self, *arg): - return self._service_context + return self.context def get_service(self, service_name, *arg): try: @@ -141,15 +150,15 @@ def get_entity(self): return self def get_client_id(self): - _val = self._service_context.work_environment.get_usage('client_id') + _val = self.context.work_environment.get_usage('client_id') if _val: return _val else: - return self._service_context.work_environment.get_preference('client_id') + return self.context.work_environment.get_preference('client_id') def setup_client_authn_methods(self, config): if config and "client_authn_methods" in config: - self._service_context.client_authn_method = client_auth_setup( + self.context.client_authn_method = client_auth_setup( config.get("client_authn_methods") ) else: @@ -158,4 +167,33 @@ def setup_client_authn_methods(self, config): s.default_authn_method]) _methods = {m: CLIENT_AUTHN_METHOD[m] for m in _default_methods if m in CLIENT_AUTHN_METHOD} - self._service_context.client_authn_method = client_auth_setup(_methods) + self.context.client_authn_method = client_auth_setup(_methods) + + def import_keys(self, keyspec): + """ + The client needs it's own set of keys. It can either dynamically + create them or load them from local storage. + This method can also fetch other entities keys provided the + URL points to a JWKS. + + :param keyspec: + """ + _keyjar = self.get_attribute('keyjar') + if _keyjar is None: + _keyjar = KeyJar() + + for where, spec in keyspec.items(): + if where == "file": + for typ, files in spec.items(): + if typ == "rsa": + for fil in files: + _key = RSAKey(priv_key=import_private_rsa_key_from_file(fil), + use="sig") + _bundle = KeyBundle() + _bundle.append(_key) + _keyjar.add_kb("", _bundle) + elif where == "url": + for iss, url in spec.items(): + _bundle = KeyBundle(source=url) + _keyjar.add_kb(iss, _bundle) + return _keyjar diff --git a/src/idpyoidc/client/http.py b/src/idpyoidc/client/http.py index e846ceed..7a7f58b3 100644 --- a/src/idpyoidc/client/http.py +++ b/src/idpyoidc/client/http.py @@ -4,7 +4,7 @@ from http.cookies import CookieError from http.cookies import SimpleCookie -import requests +from requests import request from idpyoidc.client.exception import NonFatalException from idpyoidc.client.util import sanitize from idpyoidc.client.util import set_cookie @@ -94,7 +94,7 @@ def __call__(self, url, method="GET", **kwargs): try: # Do the request - r = requests.request(method, url, **_kwargs) + r = request(method, url, **_kwargs) except Exception as err: logger.error( "http_request failed: %s, url: %s, htargs: %s, method: %s" diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 9c6239c0..dc17c031 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -129,7 +129,7 @@ def do_request( ) def set_client_id(self, client_id): - self._service_context.set("client_id", client_id) + self.get_context().set("client_id", client_id) def get_response( self, @@ -152,7 +152,7 @@ def get_response( :return: """ try: - resp = self.httpc(url, method, data=body, headers=headers) + resp = self.httpc(method, url, data=body, headers=headers) except Exception as err: logger.error("Exception on request: {}".format(err)) raise diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py index c8790c3d..13c7706f 100644 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py @@ -2,7 +2,8 @@ from cryptojwt import JWT -import requests +from requests import request + from idpyoidc.message import Message from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest @@ -64,7 +65,7 @@ def add_support( """ if http_client is None: - http_client = requests + http_client = request _service = services["authorization"] _service.upstream_get("context").add_on["pushed_authorization"] = { diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index e70df309..e7336f6a 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -136,7 +136,7 @@ def fetch_distributed_claims(self, userinfo, callback=None): service=self.get_service("userinfo"), access_token=spec["access_token"], ) - _resp = self.httpc.send(spec["endpoint"], "GET", **httpc_params) + _resp = self.httpc("GET", spec["endpoint"], **httpc_params) else: if callback: token = callback(spec["endpoint"]) @@ -144,9 +144,9 @@ def fetch_distributed_claims(self, userinfo, callback=None): httpc_params = cauth.construct( service=self.get_service("userinfo"), access_token=token ) - _resp = self.httpc.send(spec["endpoint"], "GET", **httpc_params) + _resp = self.httpc("GET", spec["endpoint"], **httpc_params) else: - _resp = self.httpc.send(spec["endpoint"], "GET") + _resp = self.httpc("GET", spec["endpoint"]) if _resp.status_code == 200: _uinfo = json.loads(_resp.text) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 16554ace..c3865544 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -258,7 +258,14 @@ def construct_request_parameter( _req = make_openid_request(req, **_mor_args) # Should the request be encrypted - return request_object_encryption(_req, _context, **kwargs) + _req = request_object_encryption(_req, _context, + self.upstream_get('attribute', 'keyjar'), + **kwargs) + + if request_param == "request": + req["request"] = _req + else: # MUST be request_uri + req["request_uri"] = self.store_request_on_file(_req, **kwargs) def oidc_post_construct(self, req, **kwargs): """ diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index 5240eeb4..ced4d6f4 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -7,7 +7,7 @@ from idpyoidc.util import rndstr -def request_object_encryption(msg, service_context, **kwargs): +def request_object_encryption(msg, service_context, keyjar, **kwargs): """ Created an encrypted JSON Web token with *msg* as body. @@ -49,12 +49,11 @@ def request_object_encryption(msg, service_context, **kwargs): if "target" not in kwargs: raise MissingRequiredAttribute("No target specified") - _keyjar = service_context.upstream_get('attribute', 'keyjar') if _kid: - _keys = _keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"], kid=_kid) + _keys = keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"], kid=_kid) _jwe["kid"] = _kid else: - _keys = _keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"]) + _keys = keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"]) return _jwe.encrypt(_keys) diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index f5622247..71501a10 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -29,12 +29,14 @@ def __init__(self, config = {} self.entity_id = entity_id or config.get('entity_id', "") + if not self.entity_id: + self.entity_id = config.get('issuer', "") - if keyjar or key_conf or config.get('key_conf') or config.get('jwks'): + if keyjar or key_conf or config.get('key_conf') or config.get('jwks') or config.get('keys'): self.keyjar = self._keyjar(keyjar, conf=config, entity_id=self.entity_id, key_conf=key_conf) else: - self.keyjar = None + self.keyjar = KeyJar() self.httpc_params = httpc_params or config.get("httpc_params", {}) @@ -62,6 +64,9 @@ def get_attribute(self, attr, *args): else: return val + def set_attribute(self, attr, val): + setattr(self, attr, val) + def get_unit(self, *args): return self @@ -122,7 +127,7 @@ def __init__(self, httpc_params=httpc_params, config=config, entity_id=entity_id, key_conf=key_conf) - self._service_context = context or None + self.context = context or None class Collection(Unit): diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index f84df605..cc31d6a9 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -107,15 +107,6 @@ def server_get(self, what, *arg): return _func(*arg) return None - def get_attribute(self, attribute, *args): - try: - getattr(self, attribute) - except AttributeError: - if self.upstream_get: - return self.upstream_get(attribute) - else: - return None - def get_endpoints(self, *arg): return self.endpoint diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index a35da3da..a5a311f7 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -237,6 +237,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No authn_info = verify_client( context=self.upstream_get("context"), + keyjar=self.upstream_get('attribute','keyjar'), request=request, http_info=http_info, **kwargs diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 34512093..67f21c30 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -9,7 +9,8 @@ from jinja2 import Environment from jinja2 import FileSystemLoader -import requests +from requests import request + from idpyoidc.context import OidcContext from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.server.configure import OPConfiguration @@ -160,7 +161,7 @@ def __init__( self.cookie_handler = cookie_handler self.claims_interface = None self.endpoint_to_authn_method = {} - self.httpc = httpc or requests + self.httpc = httpc or request self.idtoken = None self.issuer = "" # self.jwks_uri = None diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 60d26c0f..0a74922a 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -435,7 +435,7 @@ def _do_request_uri(self, request, client_id, context, **kwargs): raise ValueError("A request_uri outside the registered") # Fetch the request - _resp = context.httpc.get(_request_uri, **context.httpc_params) + _resp = context.httpc('GET', _request_uri, **context.httpc_params) if _resp.status_code == 200: args = { "keyjar": self.upstream_get('attribute', 'keyjar'), diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index e213f904..5c55fb01 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -317,8 +317,8 @@ def _verify_sector_identifier(self, request): """ si_url = request["sector_identifier_uri"] try: - res = self.upstream_get("context").httpc.get( - si_url, **self.upstream_get("context").httpc_params + res = self.upstream_get("context").httpc( + "GET", si_url, **self.upstream_get("context").httpc_params ) logger.debug("sector_identifier_uri => %s", sanitize(res.text)) except Exception as err: diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 97182772..e9d2509e 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -407,7 +407,8 @@ def do_verified_logout(self, sid, alla=False, **kwargs): _url, sjwt = spec logger.info("logging out from {} at {}".format(_cid, _url)) - res = _context.httpc.post( + res = _context.httpc( + "POST", _url, data="logout_token={}".format(sjwt), headers={"Content-Type": "application/x-www-form-urlencoded"}, diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index 5ad7264b..ce03bb0c 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -112,7 +112,6 @@ def __call__( return signer.pack(payload) def get_payload(self, token): - _context = self.upstream_get("context") verifier = JWT(key_jar=self.upstream_get('attribute','keyjar'), allowed_sign_algs=[self.alg]) try: diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 0660c1b1..d2778a7a 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -1,4 +1,5 @@ import pytest +from cryptojwt.key_jar import init_key_jar from idpyoidc.client.entity import Entity from idpyoidc.message.oauth2 import AuthorizationResponse @@ -44,6 +45,8 @@ def create_service(self): def upstream_get(self, *args): if args[0] == "context": return self.service_context + elif args[0] == 'attribute' and args[1] == 'keyjar': + return self.upstream_get('attribute','keyjar') def test_1(self): assert self.service diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 0b40190f..e62529c0 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -399,7 +399,7 @@ def test_get_key_by_kid(self, entity): request = AccessTokenRequest() # get a kid - _keys = _service_context.keyjar.get_signing_key(key_type="oct") + _keys = entity.get_attribute('keyjar').get_issuer_keys("") kid = _keys[0].kid # token_service = entity.get_service("") token_service = entity.upstream_get("service", "accesstoken") diff --git a/tests/test_client_10_entity.py b/tests/test_client_10_entity.py new file mode 100644 index 00000000..6daeca49 --- /dev/null +++ b/tests/test_client_10_entity.py @@ -0,0 +1,71 @@ +import json +import os + +import pytest +import responses + +from idpyoidc.client.entity import Entity + + +class TestClientInfo(object): + @pytest.fixture(autouse=True) + def create_client_info_instance(self): + config = { + "client_id": "client_id", + "issuer": "issuer", + "client_secret": "longenoughsupersecret", + "base_url": "https://example.com", + "requests_dir": "requests", + } + self.entity = Entity(config=config) + + def test_import_keys_file(self): + # Should only be one and that a symmetric key (client_secret) usable + # for signing and encryption + assert len(self.entity.keyjar.get_issuer_keys("")) == 1 + + file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "salesforce.key")) + + keyspec = {"file": {"rsa": [file_path]}} + self.entity.import_keys(keyspec) + + # Now there should be 2, the second a RSA key for signing + assert len(self.entity.keyjar.get_issuer_keys("")) == 2 + + def test_import_keys_url(self): + # Uses 2 variants of getting hold of the keyjar + assert len(self.entity.keyjar.get_issuer_keys("")) == 1 + + with responses.RequestsMock() as rsps: + _jwks_url = "https://foobar.com/jwks.json" + rsps.add( + "GET", + _jwks_url, + body=self.entity.get_attribute('keyjar').export_jwks_as_json(), + status=200, + adding_headers={"Content-Type": "application/json"}, + ) + keyspec = {"url": {"https://foobar.com": _jwks_url}} + self.entity.import_keys(keyspec) + + # Now there should be one belonging to https://example.com + assert len(self.entity.get_attribute('keyjar').get_issuer_keys( + "https://foobar.com")) == 1 + + def test_import_keys_file_json(self): + # Should only be one and that a symmetric key (client_secret) usable + # for signing and encryption + assert len(self.entity.keyjar.get_issuer_keys("")) == 1 + + file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "salesforce.key")) + + keyspec = {"file": {"rsa": [file_path]}} + self.entity.import_keys(keyspec) + + _entity_state = self.entity.dump(exclude_attributes=["context"]) + _jsc_state = json.dumps(_entity_state) + _o_state = json.loads(_jsc_state) + _entity = Entity().load(_o_state) + + # Now there should be 2, the second a RSA key for signing + assert len(_entity.keyjar.get_issuer_keys("")) == 2 diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index f1b0c9ec..9149ffd0 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -287,11 +287,12 @@ def test_construct(self, entity): _keyjar = token_service.upstream_get("attribute", "keyjar") _keyjar.add_kb("", kb_rsa) - _keyjar.provider_info = { + + _context = token_service.upstream_get("context") + _context.provider_info = { "issuer": "https://example.com/", "token_endpoint": "https://example.com/token", } - _context = token_service.upstream_get("context") _context.registration_response = {"token_endpoint_auth_signing_alg": "RS256"} token_service.endpoint = "https://example.com/token" @@ -383,7 +384,7 @@ def test_get_key_by_kid(self, entity): request = AccessTokenRequest() # get a kid - _keys = _service_context.keyjar.get_issuer_keys("") + _keys = entity.keyjar.get_issuer_keys("") kid = _keys[0].kid token_service = entity.get_service("accesstoken") csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) @@ -431,7 +432,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): _kb = KeyBundle() _rsa_key = new_rsa_key() _kb.append(_rsa_key) - _service_context.keyjar.add_kb("", _kb) + entity.keyjar.add_kb("", _kb) # Since I have a RSA key this doesn't fail csj.construct(request, service=token_service, authn_endpoint="token_endpoint") diff --git a/tests/test_client_13_service_context.py b/tests/test_client_13_service_context.py index e69de29b..5b3ceef8 100644 --- a/tests/test_client_13_service_context.py +++ b/tests/test_client_13_service_context.py @@ -0,0 +1,254 @@ +import os +from urllib.parse import urlsplit + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.entity import Entity +from idpyoidc.client.service_context import ServiceContext + +BASE_URL = "https://entity.example.org" + + +def test_client_info_init(): + config = { + "client_id": "client_id", + "issuer": "issuer", + "client_secret": "client_secret_wordplay", + "base_url": "https://example.com", + "requests_dir": "requests", + } + entity = Entity(entity_id=BASE_URL, config=config) + entity_copy = Entity().load(entity.dump()) + + srvcnx = entity_copy.get_context() + + for attr in config.keys(): + try: + val = getattr(srvcnx, attr) + except AttributeError: + val = srvcnx.get(attr) + + assert val == config[attr] + + +def test_set_and_get_client_secret(): + service_context = ServiceContext() + service_context.client_secret = "longenoughsupersecret" + assert service_context.client_secret == "longenoughsupersecret" + + +def test_set_and_get_client_id(): + ci = ServiceContext() + ci.client_id = "myself" + assert ci.client_id == "myself" + + +def test_client_filename(): + config = { + "client_id": "client_id", + "issuer": "issuer", + "client_secret": "longenoughsupersecret", + "base_url": "https://example.com", + "requests_dir": "requests", + } + entity = Entity(config=config) + fname = entity.get_context().filename_from_webname("https://example.com/rq12345") + assert fname == "rq12345" + + +def verify_alg_support(service_context, alg, usage, typ): + """ + Verifies that the algorithm to be used are supported by the other side. + This will look at provider information either statically configured or + obtained through dynamic provider info discovery. + + :param alg: The algorithm specification + :param usage: In which context the 'alg' will be used. + The following contexts are supported: + - userinfo + - id_token + - request_object + - token_endpoint_auth + :param typ: Type of algorithm + - signing_alg + - encryption_alg + - encryption_enc + :return: True or False + """ + + supported = service_context.provider_info["{}_{}_values_supported".format(usage, typ)] + + if alg in supported: + return True + else: + return False + + +class TestClientInfo(object): + @pytest.fixture(autouse=True) + def create_client_info_instance(self): + config = { + "client_id": "client_id", + "issuer": "issuer", + "client_secret": "longenoughsupersecret", + "base_url": "https://example.com", + "requests_dir": "requests", + } + self.entity = Entity(config=config) + self.service_context = self.entity.get_context() + + def test_registration_userinfo_sign_enc_algs(self): + self.service_context.behaviour = { + "application_type": "web", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + } + + assert self.service_context.get_sign_alg("userinfo") is None + assert self.service_context.get_enc_alg_enc("userinfo") == { + "alg": "RSA1_5", + "enc": "A128CBC-HS256", + } + + def test_registration_request_object_sign_enc_algs(self): + self.service_context.behaviour = { + "application_type": "web", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384", + } + + res = self.service_context.get_enc_alg_enc("userinfo") + # 'sign':'RS256' is an added default + assert res == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} + res = self.service_context.get_sign_alg("request_object") + assert res == "RS384" + + def test_registration_id_token_sign_enc_algs(self): + self.service_context.behaviour = { + "application_type": "web", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384", + "id_token_encrypted_response_alg": "ECDH-ES", + "id_token_encrypted_response_enc": "A128GCM", + "id_token_signed_response_alg": "ES384", + } + + res = self.service_context.get_enc_alg_enc("userinfo") + # 'sign':'RS256' is an added default + assert res == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} + res = self.service_context.get_sign_alg("request_object") + assert res == "RS384" + res = self.service_context.get_enc_alg_enc("id_token") + assert res == {"alg": "ECDH-ES", "enc": "A128GCM"} + + def test_verify_alg_support(self): + self.service_context.provider_info = { + "version": "3.0", + "issuer": "https://server.example.com", + "authorization_endpoint": "https://server.example.com/connect/authorize", + "token_endpoint": "https://server.example.com/connect/token", + "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["RS256", "ES256"], + "userinfo_endpoint": "https://server.example.com/connect/userinfo", + "check_session_iframe": "https://server.example.com/connect/check_session", + "end_session_endpoint": "https://server.example.com/connect/end_session", + "jwks_uri": "https://server.example.com/jwks.json", + "registration_endpoint": "https://server.example.com/connect/register", + "scopes_supported": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + ], + "response_types_supported": ["code", "code id_token", "id_token", "token id_token"], + "acr_values_supported": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze", + ], + "subject_types_supported": ["public", "pairwise"], + "userinfo_signing_alg_values_supported": ["RS256", "ES256", "HS256"], + "userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], + "id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"], + "id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "id_token_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], + "request_object_signing_alg_values_supported": ["none", "RS256", "ES256"], + "display_values_supported": ["page", "popup"], + "claim_types_supported": ["normal", "distributed"], + "claims_supported": [ + "sub", + "iss", + "auth_time", + "acr", + "name", + "given_name", + "family_name", + "nickname", + "profile", + "picture", + "website", + "email", + "email_verified", + "locale", + "zoneinfo", + "http://example.info/claims/groups", + ], + "claims_parameter_supported": True, + "service_documentation": "http://server.example.com/connect/service_documentation.html", + "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"], + } + + assert verify_alg_support(self.service_context, "RS256", "id_token", "signing_alg") + assert verify_alg_support(self.service_context, "RS512", "id_token", "signing_alg") is False + + assert verify_alg_support(self.service_context, "RSA1_5", "userinfo", "encryption_alg") + + # token_endpoint_auth_signing_alg_values_supported + assert verify_alg_support( + self.service_context, "ES256", "token_endpoint_auth", "signing_alg" + ) + + def test_verify_requests_uri(self): + self.service_context.provider_info = {"issuer": "https://example.com/"} + url_list = self.service_context.generate_redirect_uris("/leading") + sp = urlsplit(url_list[0]) + p = sp.path.split("/") + assert p[0] == "" + assert p[1] == "leading" + assert len(p) == 3 + + # different for different OPs + self.service_context.provider_info = {"issuer": "https://op.example.org/"} + url_list = self.service_context.generate_redirect_uris("/leading") + sp = urlsplit(url_list[0]) + np = sp.path.split("/") + assert np[0] == "" + assert np[1] == "leading" + assert len(np) == 3 + + assert np[2] != p[2] + diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index 51647cb7..976f475d 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -5,6 +5,7 @@ import responses from cryptojwt.key_jar import build_keyjar +from idpyoidc.client.entity import Entity from idpyoidc.client.service_context import ServiceContext BASE_URL = "https://example.com" @@ -106,7 +107,8 @@ def create_client_info_instance(self): "base_url": "https://example.com", "requests_dir": "requests", } - self.service_context = ServiceContext(config=config) + self.entity = Entity(config=config) + self.service_context = self.entity.get_context() def test_registration_userinfo_sign_enc_algs(self): self.service_context.work_environment.use = { diff --git a/tests/test_client_19_webfinger.py b/tests/test_client_19_webfinger.py index 1953d251..5e1e4ddc 100644 --- a/tests/test_client_19_webfinger.py +++ b/tests/test_client_19_webfinger.py @@ -261,7 +261,7 @@ def test_query_acct_resource_kwargs(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct_resource_config(self): - wf = WebFinger(ENTITY.entity_get, rel=OIC_ISSUER) + wf = WebFinger(ENTITY.unit_get, rel=OIC_ISSUER) wf.upstream_get("context").config["resource"] = "acct:carol@example.com" request_args = {} _info = wf.get_request_parameters(request_args=request_args) @@ -273,7 +273,7 @@ def test_query_acct_resource_config(self): assert qs["rel"][0] == "http://openid.net/specs/connect/1.0/issuer" def test_query_acct_no_resource(self): - wf = WebFinger(ENTITY.entity_get, rel=OIC_ISSUER) + wf = WebFinger(ENTITY.unit_get, rel=OIC_ISSUER) try: del wf.upstream_get("context").config["resource"] except KeyError: diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 595e0d45..e9ae9c1e 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1104,8 +1104,8 @@ def test_unpack_encrypted_response(self): # Add encryption key _kj = build_keyjar([{"type": "RSA", "use": ["enc"]}], issuer_id="") # Own key jar gets the private key - self.service.upstream_get("service_context").keyjar.import_jwks( - _kj.export_jwks(private=True), issuer_id="" + self.service.upstream_get("attribute",'keyjar').import_jwks( + _kj.export_jwks(private=True), issuer_id="client_id" ) # opponent gets the public key ISS_KEY.import_jwks(_kj.export_jwks(), issuer_id="client_id") diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index f903d992..1e1b42f9 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -31,7 +31,7 @@ def test_request_object_encryption(): _condition.set_usage("request_object_encryption_alg", "RSA1_5") _condition.set_usage("request_object_encryption_enc", "A128CBC-HS256") - _jwe = request_object_encryption(msg.to_json(), service_context, target=RECEIVER) + _jwe = request_object_encryption(msg.to_json(), service_context, KEYJAR, target=RECEIVER) assert _jwe _decryptor = factory(_jwe) diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index d59ae361..bec7c541 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -892,7 +892,7 @@ def test_finalize(self): p = urlparse(CLIENT_CONFIG["github"]["provider_info"]["authorization_endpoint"]) self.mock_op.register_get_response(p.path, "Redirect", 302) - _ = client.httpc(auth_query["url"]) + _ = client.httpc("GET", auth_query["url"]) # the user is redirected back to the RP with a positive response auth_response = AuthorizationResponse(code="access_code", state=auth_query["state"]) diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index a7dbfbf7..19916f46 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -157,7 +157,7 @@ def test_begin_2(self): # Calculating request so I can build a reasonable response # Publishing a JWKS instead of a JWKS_URI _context.jwks_uri = "" - _context.jwks = _context.keyjar.export_jwks() + _context.jwks = client.keyjar.export_jwks() _req = client.get_service("registration").construct_request() diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index 54d1bc99..6758e6b1 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -72,7 +72,7 @@ def test_unpack_aggregated_response(self): }, } - _jwt = JWT(key_jar=self.service.upstream_get("context").keyjar) + _jwt = JWT(key_jar=self.service.upstream_get("attribute",'keyjar')) _jws = _jwt.pack(payload=_distributed_respone) resp = { diff --git a/tests/test_server_03_authz_handling.py b/tests/test_server_03_authz_handling.py index f50e2c90..af152e81 100644 --- a/tests/test_server_03_authz_handling.py +++ b/tests/test_server_03_authz_handling.py @@ -134,7 +134,7 @@ def create_idtoken(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - server.get_attribute('attribute', 'keyjar').add_symmetric( + server.get_attribute('keyjar').add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) server.endpoint = do_endpoints(conf, server.upstream_get) diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index 5bfd019f..2bfbc56a 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -161,8 +161,8 @@ def full_path(local_file): class TestEndpoint(object): @pytest.fixture(autouse=True) def create_session_manager(self): - server = Server(conf) - self.endpoint_context = server.endpoint_context + self.server = Server(conf) + self.endpoint_context = self.server.endpoint_context self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], @@ -175,7 +175,7 @@ def create_session_manager(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.endpoint_context.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) + self.server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) self.session_manager = self.endpoint_context.session_manager self.user_id = USER_ID @@ -412,7 +412,7 @@ def test_sign_encrypt_id_token(self): assert _jws.jwt.headers["alg"] == "RS256" client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") @@ -452,7 +452,7 @@ def test_available_claims(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -465,7 +465,7 @@ def test_lifetime_default(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -482,7 +482,7 @@ def test_lifetime(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -497,7 +497,7 @@ def test_no_available_claims(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -520,7 +520,7 @@ def test_client_claims(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -539,7 +539,7 @@ def test_client_claims_with_default(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -558,7 +558,7 @@ def test_client_claims_scopes(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -580,7 +580,7 @@ def test_client_claims_scopes_per_client(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -598,7 +598,7 @@ def test_client_claims_scopes_and_request_claims_no_match(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -621,7 +621,7 @@ def test_client_claims_scopes_and_request_claims_one_match(self): id_token = self._mint_id_token(grant, session_id) client_keyjar = KeyJar() - _jwks = self.endpoint_context.keyjar.export_jwks() + _jwks = self.server.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) diff --git a/tests/test_server_09_authn_context.py b/tests/test_server_09_authn_context.py index b4b3e94a..07c5dcf9 100644 --- a/tests/test_server_09_authn_context.py +++ b/tests/test_server_09_authn_context.py @@ -164,9 +164,7 @@ def create_authn_broker(self): "code id_token token", ], } - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.server = server diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index af80c675..e8276d84 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -226,7 +226,7 @@ def test_private_key_jwt(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -247,7 +247,7 @@ def test_private_key_jwt_reusage_other_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -282,7 +282,7 @@ def test_private_key_jwt_auth_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -483,7 +483,8 @@ def test_verify_per_client(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"method": "public", "client_id": client_id} @@ -499,7 +500,8 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"method": "public", "client_id": client_id} @@ -507,7 +509,8 @@ def test_verify_per_client_per_endpoint(self): with pytest.raises(ClientAuthenticationError) as e: verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) assert e.value.args[0] == "Failed to verify client" @@ -515,7 +518,8 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} @@ -525,7 +529,8 @@ def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} @@ -546,7 +551,8 @@ def test_verify_client_jws_authn_method(self): http_info = {"headers": {}} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -558,22 +564,14 @@ def test_verify_client_bearer_body(self): self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_3"), ) assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" - # def test_verify_client_client_secret_post(self): - # request = {"client_id": client_id, "client_secret": client_secret} - # res = verify_client( - # self.endpoint_context, request, endpoint=self.server.upstream_get("endpoint", - # "endpoint_1"), - # ) - # assert set(res.keys()) == {"method", "client_id"} - # assert res["method"] == "client_secret_post" - def test_verify_client_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) token = as_unicode(base64.b64encode(as_bytes(_token))) @@ -582,6 +580,7 @@ def test_verify_client_client_secret_basic(self): res = verify_client( self.endpoint_context, + keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), @@ -598,7 +597,8 @@ def test_verify_client_bearer_header(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_2"), @@ -630,7 +630,8 @@ def test_verify_client_jws_authn_method(self): res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) assert res["method"] == "client_secret_jwt" @@ -641,7 +642,8 @@ def test_verify_client_bearer_body(self): self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_3"), ) @@ -652,7 +654,8 @@ def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) assert set(res.keys()) == {"method", "client_id"} @@ -666,7 +669,8 @@ def test_verify_client_client_secret_basic(self): res = verify_client( self.endpoint_context, - {}, + keyjar=self.server.get_attribute('keyjar'), + request={}, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -682,7 +686,8 @@ def test_verify_client_bearer_header(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_2"), @@ -695,7 +700,8 @@ def test_verify_client_authorization_none(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_2"), ) assert res["method"] == "none" @@ -706,7 +712,8 @@ def test_verify_client_registration_public(self): request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"client_id": "client_id", "method": "public"} @@ -716,7 +723,8 @@ def test_verify_client_registration_none(self): request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) assert res == {"client_id": None, "method": "none"} @@ -738,7 +746,10 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, request, endpoint=server.get_endpoint("endpoint_4") + server.endpoint_context, + keyjar=server.get_attribute('keyjar'), + request=request, + endpoint=server.get_endpoint("endpoint_4") ) assert res == {"client_id": "client_id", "method": "custom"} diff --git a/tests/test_server_20b_claims.py b/tests/test_server_20b_claims.py index 3b2759ec..81d290b4 100644 --- a/tests/test_server_20b_claims.py +++ b/tests/test_server_20b_claims.py @@ -127,7 +127,7 @@ def create_idtoken(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - server.endpoint_context.keyjar.add_symmetric( + server.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) self.claims_interface = server.endpoint_context.claims_interface diff --git a/tests/test_server_20c_authz_handling.py b/tests/test_server_20c_authz_handling.py index 8a7c1aa1..e2ea920d 100644 --- a/tests/test_server_20c_authz_handling.py +++ b/tests/test_server_20c_authz_handling.py @@ -110,7 +110,7 @@ def create_idtoken(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - server.endpoint_context.keyjar.add_symmetric( + server.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) self.session_manager = server.endpoint_context.session_manager diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index cda6f5fc..f2552f80 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -187,7 +187,7 @@ def test_private_key_jwt(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -208,7 +208,7 @@ def test_private_key_jwt_reusage_other_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -239,7 +239,7 @@ def test_private_key_jwt_auth_endpoint(self): client_keyjar.import_jwks(KEYJAR.export_jwks(private=True), CONF["issuer"]) _jwks = client_keyjar.export_jwks() - self.endpoint_context.keyjar.import_jwks(_jwks, client_id) + self.server.keyjar.import_jwks(_jwks, client_id) _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True @@ -436,7 +436,8 @@ def test_verify_per_client(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("registration"), ) assert res == {"method": "public", "client_id": client_id} @@ -452,7 +453,8 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("registration"), ) assert res == {"method": "public", "client_id": client_id} @@ -467,7 +469,8 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} @@ -477,7 +480,8 @@ def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} @@ -498,7 +502,8 @@ def test_verify_client_jws_authn_method(self): http_info = {"headers": {}} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, endpoint=self.server.get_endpoint("token"), ) @@ -510,7 +515,8 @@ def test_verify_client_bearer_body(self): self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("userinfo"), ) @@ -521,7 +527,8 @@ def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} @@ -535,6 +542,7 @@ def test_verify_client_client_secret_basic(self): res = verify_client( self.endpoint_context, + keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("token"), @@ -551,7 +559,8 @@ def test_verify_client_bearer_header(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("authorization"), @@ -582,7 +591,8 @@ def test_verify_client_jws_authn_method(self): res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("token"), ) assert res["method"] == "client_secret_jwt" @@ -593,7 +603,8 @@ def test_verify_client_bearer_body(self): self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("userinfo"), ) @@ -604,7 +615,8 @@ def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("token"), ) assert set(res.keys()) == {"method", "client_id"} @@ -618,7 +630,8 @@ def test_verify_client_client_secret_basic(self): res = verify_client( self.endpoint_context, - {}, + keyjar=self.server.get_attribute('keyjar'), + request={}, http_info=http_info, endpoint=self.server.get_endpoint("token"), ) @@ -634,7 +647,8 @@ def test_verify_client_bearer_header(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("authorization"), @@ -647,7 +661,8 @@ def test_verify_client_authorization_none(self): request = {"client_id": client_id} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("authorization"), ) assert res["method"] == "none" @@ -658,7 +673,8 @@ def test_verify_client_registration_public(self): request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("registration"), ) assert res == {"client_id": "client_id", "method": "public"} @@ -668,7 +684,8 @@ def test_verify_client_registration_none(self): request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( self.endpoint_context, - request, + keyjar=self.server.get_attribute('keyjar'), + request=request, endpoint=self.server.get_endpoint("registration"), ) assert res == {"client_id": None, "method": "none"} @@ -689,7 +706,10 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, request, endpoint=server.get_endpoint("registration") + server.endpoint_context, + keyjar=server.get_attribute('keyjar'), + request=request, + endpoint=server.get_endpoint("registration") ) assert res == {"client_id": "client_id", "method": "custom"} diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index da99488a..d7fd2687 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -195,8 +195,8 @@ def create_endpoint(self): }, "session_params": {"encrypter": SESSION_PARAMS}, } - server = Server(conf, keyjar=KEYJAR) - self.endpoint_context = server.endpoint_context + self.server = Server(conf, keyjar=KEYJAR) + self.endpoint_context = self.server.endpoint_context self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -211,7 +211,7 @@ def create_endpoint(self): } self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" - self.endpoint = server.get_endpoint("session") + self.endpoint = self.server.get_endpoint("session") def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -247,7 +247,7 @@ def test_parse(self): "access_token", grant, session_id, code, resources=[AUTH_REQ["client_id"]] ) - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(access_token.value) assert _info["token_class"] == "access_token" @@ -399,8 +399,8 @@ def create_endpoint(self): "scopes_to_claims": _scope2claims, "session_params": SESSION_PARAMS, } - server = Server(conf, keyjar=KEYJAR) - self.endpoint_context = server.endpoint_context + self.server = Server(conf, keyjar=KEYJAR) + self.endpoint_context = self.server.endpoint_context self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -415,7 +415,7 @@ def create_endpoint(self): } self.session_manager = self.endpoint_context.session_manager self.user_id = "diana" - self.endpoint = server.get_endpoint("session") + self.endpoint = self.server.get_endpoint("session") def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -459,7 +459,7 @@ def test_parse(self): "access_token", grant, session_id, code, resources=[_auth_req["client_id"]] ) - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(access_token.value) assert _info["token_class"] == "access_token" @@ -490,7 +490,7 @@ def test_mint_with_aud(self): aud=["https://audience.example.com"], ) - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(access_token.value) assert _info["token_class"] == "access_token" @@ -521,7 +521,7 @@ def test_mint_with_scope(self): aud=["https://audience.example.com"], ) - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(access_token.value) assert _info["token_class"] == "access_token" @@ -551,7 +551,7 @@ def test_mint_with_extra(self): claims=["name", "family_name"], ) - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(access_token.value) assert "name" in _info assert "family_name" in _info @@ -561,6 +561,6 @@ def test_token_handler(self): _handler = master_handler["access_token"] assert _handler _jwt = _handler(aud="https://example.org") - _verifier = JWT(self.endpoint_context.keyjar) + _verifier = JWT(self.server.keyjar) _info = _verifier.unpack(_jwt) assert _info diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 9e6efcf0..bff468a0 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -288,18 +288,9 @@ def test_sector_uri_missing_redirect_uri(self): _msg["application_type"] = "native" _msg["sector_identifier_uri"] = _url - with responses.RequestsMock() as rsps: - rsps.add( - "GET", - _url, - body=json.dumps(["https://example.com", "https://example.org"]), - adding_headers={"Content-Type": "application/json"}, - status=200, - ) - - _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) - _resp = self.endpoint.process_request(request=_req) - assert "error" in _resp + _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) + _resp = self.endpoint.process_request(request=_req) + assert "error" in _resp def test_incorrect_request(self): _msg = MSG.copy() diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index fd12ca3c..39e7af37 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -262,9 +262,7 @@ def create_endpoint(self): endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.endpoint_context = endpoint_context self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager @@ -272,7 +270,7 @@ def create_endpoint(self): self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.upstream_get("context").keyjar.add_symmetric( + self.endpoint.upstream_get("attribute",'keyjar').add_symmetric( "client_1", "hemligtkodord1234567890" ) diff --git a/tests/test_server_24_oauth2_authorization_endpoint_jar.py b/tests/test_server_24_oauth2_authorization_endpoint_jar.py index ee275e42..922a65ed 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_server_24_oauth2_authorization_endpoint_jar.py @@ -190,16 +190,14 @@ def create_endpoint(self): endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - endpoint_context.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + server.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") def test_parse_request_parameter(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 94939131..a03e7034 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -186,7 +186,7 @@ def create_endpoint(self, conf): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = endpoint_context.session_manager self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index aa2dfb47..e98326ad 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -294,8 +294,8 @@ def create_endpoint(self): _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + server.keyjar.import_jwks( + server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint_context = endpoint_context self.endpoint = server.get_endpoint("authorization") @@ -304,7 +304,8 @@ def create_endpoint(self): self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - endpoint_context.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + server.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + self.server = server def test_init(self): assert self.endpoint @@ -434,7 +435,7 @@ def test_id_token_claims(self): _resp = self.endpoint.process_request(_pr_resp) idt = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.upstream_get("context").keyjar, + keyjar=self.endpoint.upstream_get("attribute","keyjar") ) assert idt # from config @@ -459,7 +460,7 @@ def test_id_token_acr(self): _resp = self.endpoint.process_request(_pr_resp) res = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.upstream_get("context").keyjar, + keyjar=self.endpoint.upstream_get("attribute","keyjar"), ) assert res res = _resp["response_args"][verified_claim_name("id_token")] @@ -1243,8 +1244,8 @@ def create_endpoint(self): _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + server.keyjar.import_jwks( + server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = server.get_endpoint("authorization") self.session_manager = endpoint_context.session_manager @@ -1252,7 +1253,7 @@ def create_endpoint(self): self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - endpoint_context.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + server.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") def test_setup_acr_claim(self): request = AuthorizationRequest( diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index 7d9cc772..20ca42c8 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -336,7 +336,7 @@ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): http_info = {"cookie": [cookie]} msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute",'keyjar')) msg2 = Message(id_token_hint=id_token) msg2[verified_claim_name("id_token_hint")] = msg[verified_claim_name("id_token")] @@ -403,7 +403,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): post_logout_redirect_uri = "https://demo.example.com/log_out" msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("context").keyjar) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute",'keyjar')) with pytest.raises(RedirectURIError): self.session_endpoint.process_request( diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index f14ec928..47664844 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -206,9 +206,8 @@ def create_endpoint(self, jwt_token): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - endpoint_context.keyjar.import_jwks_as_json( - endpoint_context.keyjar.export_jwks_as_json(private=True), - endpoint_context.issuer, + server.keyjar.import_jwks_as_json( + server.keyjar.export_jwks_as_json(private=True),endpoint_context.issuer ) self.introspection_endpoint = server.get_endpoint("introspection") self.token_endpoint = server.get_endpoint("token") diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index a6942938..8cd3bd15 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -231,9 +231,7 @@ def create_server(config): endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), config["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), config["issuer"]) return server diff --git a/tests/test_server_34_oidc_sso.py b/tests/test_server_34_oidc_sso.py index d85c5510..b66ddf76 100755 --- a/tests/test_server_34_oidc_sso.py +++ b/tests/test_server_34_oidc_sso.py @@ -199,14 +199,14 @@ def create_endpoint_context(self): endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + server.keyjar.import_jwks( + server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = server.get_endpoint("authorization") self.endpoint_context = endpoint_context self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - endpoint_context.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + server.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") def test_sso(self): request = self.endpoint.parse_request(AUTH_REQ_DICT) diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index 3e8b4dd6..34f3ca24 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -202,9 +202,9 @@ def conf(): class TestEndpoint(_TestEndpoint): @pytest.fixture(autouse=True) def create_endpoint(self, conf): - server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + endpoint_context = self.server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -213,10 +213,10 @@ def create_endpoint(self, conf): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") endpoint_context.userinfo = USERINFO self.session_manager = endpoint_context.session_manager - self.token_endpoint = server.get_endpoint("token") + self.token_endpoint = self.server.get_endpoint("token") self.user_id = "diana" self.endpoint_context = endpoint_context @@ -395,7 +395,7 @@ def test_do_refresh_access_token(self): "scope", } AuthorizationResponse().from_jwt( - _resp["response_args"]["id_token"], _cntx.keyjar, sender="" + _resp["response_args"]["id_token"], self.server.get_attribute('keyjar'), sender="" ) msg = self.token_endpoint.do_response(request=_req, **_resp) @@ -449,7 +449,7 @@ def test_do_2nd_refresh_access_token(self): "scope", } AuthorizationResponse().from_jwt( - _2nd_resp["response_args"]["id_token"], _cntx.keyjar, sender="" + _2nd_resp["response_args"]["id_token"], self.server.keyjar, sender="" ) msg = self.token_endpoint.do_response(request=_req, **_resp) @@ -508,7 +508,7 @@ def test_refresh_scopes(self): } AuthorizationResponse().from_jwt( _resp["response_args"]["id_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) @@ -619,7 +619,7 @@ def test_refresh_more_scopes_2(self): } AuthorizationResponse().from_jwt( _resp["response_args"]["id_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) @@ -649,7 +649,7 @@ def test_refresh_less_scopes(self): _resp = self.token_endpoint.process_request(request=_req) idtoken = AuthorizationResponse().from_jwt( _resp["response_args"]["id_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) @@ -674,7 +674,7 @@ def test_refresh_less_scopes(self): _resp = self.token_endpoint.process_request(request=_req) idtoken = AuthorizationResponse().from_jwt( _resp["response_args"]["id_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) @@ -763,7 +763,7 @@ def test_refresh_no_offline_access_scope(self): } AuthorizationResponse().from_jwt( _resp["response_args"]["id_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) assert _resp["response_args"]["scope"] == ["openid"] @@ -959,7 +959,7 @@ def test_access_token_lifetime(self): access_token = AccessTokenRequest().from_jwt( _resp["response_args"]["access_token"], - self.endpoint_context.keyjar, + self.server.keyjar, sender="", ) @@ -1028,7 +1028,7 @@ def create_endpoint(self, conf): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = endpoint_context.session_manager self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" @@ -1120,7 +1120,7 @@ def test_old_jwt_token(self): # payload.update(kwargs) _context = _handler.upstream_get("endpoint_context") signer = JWT( - key_jar=_context.keyjar, + key_jar=_handler.upstream_get('attribute', 'keyjar'), iss=_handler.issuer, lifetime=300, sign_alg=_handler.alg, diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 1287b817..d6729313 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -203,7 +203,7 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } - self.endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.endpoint = server.get_endpoint("token") self.introspection_endpoint = server.get_endpoint("introspection") self.session_manager = self.endpoint_context.session_manager diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 4caa190f..6a0aaffe 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -167,13 +167,13 @@ def create_endpoint(self): endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = verify_oidc_client_information(_clients["oidc_clients"]) - endpoint_context.keyjar.import_jwks( - endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + server.keyjar.import_jwks( + server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.rp_keyjar = init_key_jar(key_defs=KEYDEFS, issuer_id="s6BhdRkqt3") # Add RP's keys to the OP's keyjar - endpoint_context.keyjar.import_jwks( + server.keyjar.import_jwks( self.rp_keyjar.export_jwks(issuer_id="s6BhdRkqt3"), "s6BhdRkqt3" ) diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 358f8478..12697787 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -205,9 +205,11 @@ def create_endpoint(self): server1 = Server( OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR ) + server2 = Server( OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR ) + # The top most part (Server class instance) is not server1.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -218,6 +220,7 @@ def create_endpoint(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] } + # make server2 endpoint context a copy of server 1 endpoint context _store = server1.endpoint_context.dump() server2.endpoint_context.load( _store, @@ -294,12 +297,12 @@ def test_init(self): self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] ) == {"openid"} assert set( - self.endpoint[1].upstream_get("endpoint_context").provider_info["claims_supported"] - ) == set(self.endpoint[2].upstream_get("endpoint_context").provider_info["claims_supported"]) + self.endpoint[1].upstream_get("context").provider_info["claims_supported"] + ) == set(self.endpoint[2].upstream_get("context").provider_info["claims_supported"]) def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("context").authz(session_id, AUTH_REQ) # grant, session_id = self._do_grant(AUTH_REQ, index=1) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) @@ -315,7 +318,7 @@ def test_parse(self): def test_process_request(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) @@ -328,7 +331,7 @@ def test_process_request(self): def test_process_request_not_allowed(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -362,7 +365,7 @@ def test_process_request_not_allowed(self): def test_do_response(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -380,12 +383,12 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint[2].upstream_get("endpoint_context").cdb["client_1"][ + self.endpoint[2].upstream_get("context").cdb["client_1"][ "userinfo_signed_response_alg" ] = "ES256" session_id = self._create_session(AUTH_REQ, index=2) - grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[2].upstream_get("context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) @@ -404,13 +407,13 @@ def test_custom_scope(self): _auth_req["scope"] = ["openid", "research_and_scholarship"] session_id = self._create_session(_auth_req, index=2) - grant = self.endpoint[2].upstream_get("endpoint_context").authz(session_id, _auth_req) + grant = self.endpoint[2].upstream_get("context").authz(session_id, _auth_req) self._dump_restore(2, 1) grant.claims = { "userinfo": self.endpoint[1] - .upstream_get("endpoint_context") + .upstream_get("context") .claims_interface.get_claims( session_id, scopes=_auth_req["scope"], claims_release_point="userinfo" ) @@ -448,7 +451,7 @@ def test_sman_db_integrity(self): it show that flush and loads method will keep order, anyway. """ session_id = self._create_session(AUTH_REQ, index=1) - grant = self.endpoint[1].upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.endpoint[1].upstream_get("context").authz(session_id, AUTH_REQ) sman = self.session_manager[1] session_dump = sman.dump() diff --git a/tests/x_test_ciba_01_backchannel_auth.py b/tests/x_test_ciba_01_backchannel_auth.py index 8d8b9969..d55d77e2 100644 --- a/tests/x_test_ciba_01_backchannel_auth.py +++ b/tests/x_test_ciba_01_backchannel_auth.py @@ -13,16 +13,16 @@ from idpyoidc.message.oidc.backchannel_authentication import AuthenticationRequest from idpyoidc.message.oidc.backchannel_authentication import NotificationRequest from idpyoidc.message.oidc.backchannel_authentication import TokenRequest -from idpyoidc.server import OPConfiguration -from idpyoidc.server import Server -from idpyoidc.server import init_service -from idpyoidc.server import init_user_info -from idpyoidc.server import user_info -from idpyoidc.server.authn_event import create_authn_event -from idpyoidc.server.client_authn import verify_client -from idpyoidc.server.oidc.backchannel_authentication import BackChannelAuthentication -from idpyoidc.server.oidc.token import Token -from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.self.server import OPConfiguration +from idpyoidc.self.server import Server +from idpyoidc.self.server import init_service +from idpyoidc.self.server import init_user_info +from idpyoidc.self.server import user_info +from idpyoidc.self.server.authn_event import create_authn_event +from idpyoidc.self.server.client_authn import verify_client +from idpyoidc.self.server.oidc.backchannel_authentication import BackChannelAuthentication +from idpyoidc.self.server.oidc.token import Token +from idpyoidc.self.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from . import CRYPT_CONFIG from . import SESSION_PARAMS @@ -129,7 +129,7 @@ def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) "jwks_file": "private/token_jwks.json", "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, "token": { - "class": "idpyoidc.server.token.jwt_token.JWTToken", + "class": "idpyoidc.self.server.token.jwt_token.JWTToken", "kwargs": { "lifetime": 3600, "base_claims": {"eduperson_scoped_affiliation": None}, @@ -138,14 +138,14 @@ def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) }, }, "refresh": { - "class": "idpyoidc.server.token.jwt_token.JWTToken", + "class": "idpyoidc.self.server.token.jwt_token.JWTToken", "kwargs": { "lifetime": 3600, "aud": ["https://example.org/appl"], }, }, "id_token": { - "class": "idpyoidc.server.token.id_token.IDToken", + "class": "idpyoidc.self.server.token.id_token.IDToken", "kwargs": { "base_claims": { "email": {"essential": True}, @@ -174,7 +174,7 @@ def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) "authentication": { "anon": { "acr": INTERNETPROTOCOLPASSWORD, - "class": "idpyoidc.server.user_authn.user.NoAuthn", + "class": "idpyoidc.self.server.user_authn.user.NoAuthn", "kwargs": {"user": "diana"}, } }, @@ -190,8 +190,8 @@ def parse_login_hint_token(keyjar: KeyJar, login_hint_token: str, context=None) class TestBCAEndpoint(object): @pytest.fixture(autouse=True) def create_endpoint(self): - server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) - self.endpoint_context = server.endpoint_context + self.server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) + self.endpoint_context = self.server.endpoint_context self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -199,33 +199,33 @@ def create_endpoint(self): "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], } - self.endpoint = server.get_endpoint("backchannel_authentication") - self.token_endpoint = server.get_endpoint("token") + self.endpoint = self.server.get_endpoint("backchannel_authentication") + self.token_endpoint = self.server.get_endpoint("token") self.client_keyjar = build_keyjar(KEYDEFS) - # Add servers keys - self.client_keyjar.import_jwks(server.endpoint_context.keyjar.export_jwks(), ISSUER) + # Add self.servers keys + self.client_keyjar.import_jwks(self.server.keyjar.export_jwks(), ISSUER) # The only own key the client has a this point self.client_keyjar.add_symmetric("", CLIENT_SECRET, ["sig"]) # Need to add the client_secret as a symmetric key bound to the client_id - server.endpoint_context.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) - server.endpoint_context.keyjar.import_jwks(self.client_keyjar.export_jwks(), CLIENT_ID) + self.server.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) + self.server.keyjar.import_jwks(self.client_keyjar.export_jwks(), CLIENT_ID) - server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + self.server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} # login_hint - server.endpoint_context.login_hint_lookup = init_service( - {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, None + self.server.endpoint_context.login_hint_lookup = init_service( + {"class": "idpyoidc.self.server.login_hint.LoginHintLookup"}, None ) # userinfo _userinfo = init_user_info( { - "class": "idpyoidc.server.user_info.UserInfo", + "class": "idpyoidc.self.server.user_info.UserInfo", "kwargs": {"db_file": full_path("users.json")}, }, "", ) - server.endpoint_context.login_hint_lookup.userinfo = _userinfo - self.session_manager = server.endpoint_context.session_manager + self.server.endpoint_context.login_hint_lookup.userinfo = _userinfo + self.session_manager = self.server.endpoint_context.session_manager def test_login_hint_token(self): _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") @@ -490,11 +490,11 @@ def test_login_hint_response(self): class TestBCAEndpointService(object): @pytest.fixture(autouse=True) def create_endpoint(self): - self.ciba = {"server": self._create_server(), "client": self._create_ciba_client()} + self.ciba = {"self.server": self._create_self.server(), "client": self._create_ciba_client()} - def _create_server(self): - server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) - endpoint_context = server.endpoint_context + def _create_self.server(self): + self.server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) + endpoint_context = self.server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -504,29 +504,29 @@ def _create_server(self): } client_keyjar = build_keyjar(KEYDEFS) - # Add servers keys - client_keyjar.import_jwks(server.endpoint_context.keyjar.export_jwks(), ISSUER) + # Add self.servers keys + client_keyjar.import_jwks(self.server.keyjar.export_jwks(), ISSUER) # The only own key the client has a this point client_keyjar.add_symmetric("", CLIENT_SECRET, ["sig"]) # Need to add the client_secret as a symmetric key bound to the client_id - server.endpoint_context.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) - server.endpoint_context.keyjar.import_jwks(client_keyjar.export_jwks(), CLIENT_ID) + self.server.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) + self.server.keyjar.import_jwks(client_keyjar.export_jwks(), CLIENT_ID) - server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + self.server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} # login_hint - server.endpoint_context.login_hint_lookup = init_service( - {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, None + self.server.endpoint_context.login_hint_lookup = init_service( + {"class": "idpyoidc.self.server.login_hint.LoginHintLookup"}, None ) # userinfo _userinfo = init_user_info( { - "class": "idpyoidc.server.user_info.UserInfo", + "class": "idpyoidc.self.server.user_info.UserInfo", "kwargs": {"db_file": full_path("users.json")}, }, "", ) - server.endpoint_context.login_hint_lookup.userinfo = _userinfo - return server + self.server.endpoint_context.login_hint_lookup.userinfo = _userinfo + return self.server def _create_ciba_client(self): config = { @@ -560,13 +560,13 @@ def _create_session(self, user_id, auth_req, sub_type="public", sector_identifie authz_req = auth_req client_id = authz_req["client_id"] ae = create_authn_event(user_id) - _session_manager = self.ciba["server"].endpoint_context.session_manager + _session_manager = self.ciba["self.server"].endpoint_context.session_manager return _session_manager.create_session( ae, authz_req, user_id, client_id=client_id, sub_type=sub_type ) def test_client_notification(self): - _keyjar = self.ciba["server"].endpoint_context.keyjar + _keyjar = self.ciba["self.server"].endpoint_context.keyjar _jwt = JWT(_keyjar, iss=CLIENT_ID, sign_alg="HS256") _jwt.with_jti = True _assertion = _jwt.pack({"aud": [ISSUER]}) @@ -580,14 +580,14 @@ def test_client_notification(self): "login_hint": "mail:diana@example.org", } - _authn_endpoint = self.ciba["server"].upstream_get("endpoint", "backchannel_authentication") + _authn_endpoint = self.ciba["self.server"].upstream_get("endpoint", "backchannel_authentication") req = AuthenticationRequest(**request) req = _authn_endpoint.parse_request(req.to_urlencoded()) _info = _authn_endpoint.process_request(req) assert _info - _session_manager = self.ciba["server"].endpoint_context.session_manager + _session_manager = self.ciba["self.server"].endpoint_context.session_manager sid = _session_manager.auth_req_id_map[_info["response_args"]["auth_req_id"]] _user_id, _client_id, _grant_id = _session_manager.decrypt_session_id(sid) From 44009156562d2821c78ca1d136913f4bbc3e4c7f Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 20:05:55 +0100 Subject: [PATCH 44/76] Rebased onto improved --- .../oauth2/add_on/pushed_authorization.py | 9 +- src/idpyoidc/node.py | 13 ++- src/idpyoidc/server/endpoint.py | 109 +++++++++++------- 3 files changed, 81 insertions(+), 50 deletions(-) diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py index 13c7706f..4c67bdcd 100644 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py @@ -19,6 +19,8 @@ def push_authorization(request_args, service, **kwargs): _context = service.upstream_get("context") method_args = _context.add_on["pushed_authorization"] + if method_args['apply'] is False: + return request_args # construct the message body if method_args["body_format"] == "urlencoded": @@ -36,8 +38,10 @@ def push_authorization(request_args, service, **kwargs): _body = _msg.to_urlencoded() # Send it to the Pushed Authorization Request Endpoint - resp = method_args["http_client"].get( - _context.provider_info["pushed_authorization_request_endpoint"], data=_body + resp = method_args["http_client"]( + method="GET", + url=_context.provider_info["pushed_authorization_request_endpoint"], + data=_body ) if resp.status_code == 200: @@ -73,6 +77,7 @@ def add_support( "signing_algorithm": signing_algorithm, "http_client": http_client, "merge_rule": merge_rule, + 'apply': True } _service.post_construct.append(push_authorization) diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index 71501a10..0f6b28fa 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -36,7 +36,7 @@ def __init__(self, self.keyjar = self._keyjar(keyjar, conf=config, entity_id=self.entity_id, key_conf=key_conf) else: - self.keyjar = KeyJar() + self.keyjar = None self.httpc_params = httpc_params or config.get("httpc_params", {}) @@ -45,7 +45,7 @@ def __init__(self, self.keyjar.httpc_params = self.httpc_params def unit_get(self, what, *arg): - _func = getattr(self, "get_{}".format(what), None) + _func = getattr(self, f"get_{what}", None) if _func: return _func(*arg) return None @@ -102,9 +102,12 @@ def _keyjar(self, return keyjar -def find_topmost_unit(unit): - while hasattr(unit, 'upstream_get'): - unit = unit.upstream_get('unit') +def topmost_unit(unit): + if hasattr(unit, 'upstream_get'): + if unit.upstream_get: + next_unit = unit.upstream_get('unit') + if next_unit: + unit = topmost_unit(next_unit) return unit diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index a5a311f7..5de31940 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -5,6 +5,8 @@ from typing import Union from urllib.parse import urlparse +from cryptojwt.exception import IssuerNotFound + from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import MissingRequiredValue from idpyoidc.exception import ParameterError @@ -142,12 +144,42 @@ def process_verify_error(self, exception): _error = "invalid_request" return self.error_cls(error=_error, error_description="%s" % exception) + def find_client_keys(self, iss): + return False + + def verify_request(self, request, keyjar, client_id, verify_args, lap=0): + # verify that the request message is correct, may have to do it twice + try: + if verify_args is None: + request.verify(keyjar=keyjar, opponent_id=client_id) + else: + request.verify(keyjar=keyjar, opponent_id=client_id, **verify_args) + except (MissingRequiredAttribute, ValueError, MissingRequiredValue, ParameterError) as err: + _error = "invalid_request" + if isinstance(err, ValueError) and self.request_cls == RegistrationRequest: + if len(err.args) > 1: + if err.args[1] == "initiate_login_uri": + _error = "invalid_client_metadata" + + return self.error_cls(error=_error, error_description="%s" % err) + except IssuerNotFound as err: + if lap: + return self.error_cls(error=err) + client_id =self.find_client_keys(err.args[0]) + if not client_id: + return self.error_cls(error=err) + else: + # Fund a client ID I believe will work + self.verify_request(request=request, keyjar=keyjar, client_id=client_id, + verify_args=verify_args, lap=1) + return None + def parse_request( - self, - request: Union[Message, dict, str], - http_info: Optional[dict] = None, - verify_args: Optional[dict] = None, - **kwargs + self, + request: Union[Message, dict, str], + http_info: Optional[dict] = None, + verify_args: Optional[dict] = None, + **kwargs ): """ @@ -197,20 +229,11 @@ def parse_request( else: _client_id = req.get("client_id") - # verify that the request message is correct - try: - if verify_args is None: - req.verify(keyjar=_keyjar, opponent_id=_client_id) - else: - req.verify(keyjar=_keyjar, opponent_id=_client_id, **verify_args) - except (MissingRequiredAttribute, ValueError, MissingRequiredValue, ParameterError) as err: - _error = "invalid_request" - if isinstance(err, ValueError) and self.request_cls == RegistrationRequest: - if len(err.args) > 1: - if err.args[1] == "initiate_login_uri": - _error = "invalid_client_metadata" - - return self.error_cls(error=_error, error_description="%s" % err) + # verify that the request message is correct, may have to do it twice + err_response = self.verify_request(request=req, keyjar=_keyjar, client_id=_client_id, + verify_args=verify_args) + if err_response: + return err_response LOGGER.info("Parsed and verified request: %s" % sanitize(req)) @@ -237,7 +260,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No authn_info = verify_client( context=self.upstream_get("context"), - keyjar=self.upstream_get('attribute','keyjar'), + keyjar=self.upstream_get('attribute', 'keyjar'), request=request, http_info=http_info, **kwargs @@ -252,7 +275,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No return authn_info def do_post_parse_request( - self, request: Message, client_id: Optional[str] = "", **kwargs + self, request: Message, client_id: Optional[str] = "", **kwargs ) -> Message: _context = self.upstream_get("context") for meth in self.post_parse_request: @@ -262,7 +285,7 @@ def do_post_parse_request( return request def do_pre_construct( - self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs + self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.pre_construct: @@ -271,10 +294,10 @@ def do_pre_construct( return response_args def do_post_construct( - self, - response_args: Union[Message, dict], - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Union[Message, dict], + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.post_construct: @@ -283,10 +306,10 @@ def do_post_construct( return response_args def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs ) -> Union[Message, dict]: """ @@ -297,10 +320,10 @@ def process_request( return {} def construct( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ): """ Construct the response @@ -318,19 +341,19 @@ def construct( return self.do_post_construct(response, request, **kwargs) def response_info( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: return self.construct(response_args, request, **kwargs) def do_response( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - error: Optional[str] = "", - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + error: Optional[str] = "", + **kwargs ) -> dict: """ :param response_args: Information to use when constructing the response From ffbf880159a034c6dc2977be525f8334303f4c04 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 20:07:59 +0100 Subject: [PATCH 45/76] Rebased onto improved --- src/idpyoidc/node.py | 11 +++++++++-- tests/test_client_06_client_authn.py | 9 +++++++++ tests/test_client_10_entity.py | 4 +++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index 0f6b28fa..d4febc6c 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -28,15 +28,22 @@ def __init__(self, if config is None: config = {} + _client_id = config.get('client_id', "") self.entity_id = entity_id or config.get('entity_id', "") if not self.entity_id: - self.entity_id = config.get('issuer', "") + self.entity_id = _client_id if keyjar or key_conf or config.get('key_conf') or config.get('jwks') or config.get('keys'): self.keyjar = self._keyjar(keyjar, conf=config, entity_id=self.entity_id, key_conf=key_conf) + if _client_id: + self.keyjar.add_symmetric('', _client_id) else: - self.keyjar = None + if _client_id: + self.keyjar = KeyJar() + self.keyjar.add_symmetric('', _client_id) + else: + self.keyjar = None self.httpc_params = httpc_params or config.get("httpc_params", {}) diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index e62529c0..321d0203 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -34,11 +34,17 @@ BASE_PATH = os.path.abspath(os.path.dirname(__file__)) CLIENT_ID = "A" +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + CLIENT_CONF = { "issuer": "https://example.com/as", # "redirect_uris": ["https://example.com/cli/authz_cb"], "client_secret": "white boarding pass", "client_id": CLIENT_ID, + "key_conf": {'key_defs': KEYSPEC} } KEY_CONF = { @@ -441,6 +447,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): csj = ClientSecretJWT() request = AccessTokenRequest() + # No preference -> default == RS256 _service_context.registration_response = {} token_service = entity.get_service("") @@ -448,6 +455,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Since I have an RSA key this doesn't fail csj.construct(request, service=token_service, authn_endpoint="token_endpoint") + _rsa_key = entity.keyjar.get(key_use='sig', key_type='rsa', issuer_id='')[0] _jws = factory(request["client_assertion"]) assert _jws.jwt.headers["alg"] == "RS256" _rsa_key = _service_context.keyjar.get_signing_key(key_type="RSA")[0] @@ -471,6 +479,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): ] csj.construct(request, service=token_service, authn_endpoint="token_endpoint") + _ec_key = entity.keyjar.get(key_use='sig', key_type='ec', issuer_id='')[0] _jws = factory(request["client_assertion"]) # Should be ES256 since I have a key for ES256 assert _jws.jwt.headers["alg"] == "ES256" diff --git a/tests/test_client_10_entity.py b/tests/test_client_10_entity.py index 6daeca49..095e2a5f 100644 --- a/tests/test_client_10_entity.py +++ b/tests/test_client_10_entity.py @@ -6,6 +6,8 @@ from idpyoidc.client.entity import Entity +KEYSPEC = [{"type": "RSA", "use": ["sig"]}] + class TestClientInfo(object): @pytest.fixture(autouse=True) @@ -29,7 +31,7 @@ def test_import_keys_file(self): keyspec = {"file": {"rsa": [file_path]}} self.entity.import_keys(keyspec) - # Now there should be 2, the second a RSA key for signing + # Now there should be 3, 2 RSA keys assert len(self.entity.keyjar.get_issuer_keys("")) == 2 def test_import_keys_url(self): From 4c473084555311b48dbaa2295508d76a5dd3f36d Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 20:11:01 +0100 Subject: [PATCH 46/76] Rebased onto improved --- src/idpyoidc/client/oidc/registration.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 1e29935c..3c0b019f 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -66,15 +66,22 @@ def update_service_context(self, resp, key="", **kwargs): _context = self.upstream_get("context") _context.map_preferred_to_registered(resp) - _keyjar = _context.keyjar _context.registration_response = resp _client_id = _context.get_usage("client_id") if _client_id: - if _client_id not in _keyjar: - _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) + _context.client_id = _client_id + _keyjar = self.upstream_get('attribute', 'keyjar') + if _keyjar: + if _client_id not in _keyjar: + _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) _client_secret = _context.get_usage("client_secret") if _client_secret: + if not _keyjar: + _entity = self.upstream_get('unit') + _keyjar = _entity.keyjar = KeyJar() + + _context.client_secret = _client_secret _keyjar.add_symmetric("", _client_secret) _keyjar.add_symmetric(_client_id, _client_secret) try: From 18212474ea0221ac9b42cdd99623a96692dfc6d5 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Dec 2022 20:13:53 +0100 Subject: [PATCH 47/76] Rebased onto improved --- src/idpyoidc/client/entity.py | 10 ++++-- src/idpyoidc/client/rp_handler.py | 7 ++-- src/idpyoidc/node.py | 50 ++++++++++++++++----------- src/idpyoidc/server/__init__.py | 30 ++++++++-------- tests/test_server_08_id_token.py | 2 +- tests/test_server_13_user_authn.py | 12 +++---- tests/test_server_16_endpoint.py | 2 +- tests/test_server_17_client_authn.py | 30 ++++++++-------- tests/test_server_20d_client_authn.py | 14 ++++---- tests/test_server_20f_userinfo.py | 4 +-- 10 files changed, 88 insertions(+), 73 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 5db28cc7..7455ddf6 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -73,7 +73,7 @@ def redirect_uris_from_callback_uris(callback_uris): return res -class Entity(Unit): +class Entity(Unit): # This is a Client parameter = { 'entity_id': None, 'jwks_uri': None, @@ -97,9 +97,15 @@ def __init__( key_conf: Optional[dict] = None, entity_id: Optional[str] = '' ): + if config is None: + config = {} + + self.entity_id = entity_id or config.get('entity_id') + self.client_id = config.get('client_id', entity_id) + Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, httpc_params=httpc_params, config=config, key_conf=key_conf, - entity_id=entity_id) + client_id=self.client_id) if context: self.context = context diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index b444e107..ab847249 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -4,6 +4,7 @@ from typing import Optional from cryptojwt import as_unicode +from cryptojwt import KeyJar from cryptojwt.key_bundle import keybundle_from_local_file from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes @@ -17,7 +18,6 @@ from idpyoidc.exception import MessageException from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import NotForMe -from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import AuthorizationResponse @@ -37,6 +37,7 @@ class RPHandler(object): + def __init__( self, base_url: Optional[str] = "", @@ -201,7 +202,9 @@ def init_client(self, issuer): if _context.iss_hash: self.hash2issuer[_context.iss_hash] = issuer # If non persistent - _keyjar = client.get_attribute('keyjar') + _keyjar = client.keyjar + if not _keyjar: + _keyjar = client.keyjar = KeyJar() _keyjar.load(self.keyjar.dump()) # If persistent nothings has to be copied diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index d4febc6c..e70fc1f1 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -18,8 +18,9 @@ def __init__(self, httpc: Optional[object] = None, httpc_params: Optional[dict] = None, config: Optional[Union[Configuration, dict]] = None, - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None + key_conf: Optional[dict] = None, + issuer_id: Optional[str] = '', + client_id: Optional[str] = '' ): ImpExp.__init__(self) self.upstream_get = upstream_get @@ -28,20 +29,16 @@ def __init__(self, if config is None: config = {} - _client_id = config.get('client_id', "") - self.entity_id = entity_id or config.get('entity_id', "") - if not self.entity_id: - self.entity_id = _client_id - if keyjar or key_conf or config.get('key_conf') or config.get('jwks') or config.get('keys'): - self.keyjar = self._keyjar(keyjar, conf=config, entity_id=self.entity_id, - key_conf=key_conf) - if _client_id: - self.keyjar.add_symmetric('', _client_id) + # Should be either one + id = issuer_id or client_id + self.keyjar = self._keyjar(keyjar, conf=config, key_conf=key_conf, id=id) + if client_id: + self.keyjar.add_symmetric('', client_id) else: - if _client_id: + if client_id: self.keyjar = KeyJar() - self.keyjar.add_symmetric('', _client_id) + self.keyjar.add_symmetric('', client_id) else: self.keyjar = None @@ -80,8 +77,9 @@ def get_unit(self, *args): def _keyjar(self, keyjar: Optional[KeyJar] = None, conf: Optional[Union[dict, Configuration]] = None, - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None): + key_conf: Optional[dict] = None, + id: Optional[str] = "", + ): if keyjar is None: if key_conf: keys_args = {k: v for k, v in key_conf.items() if k != "uri_path"} @@ -100,9 +98,9 @@ def _keyjar(self, else: _keyjar = None - if _keyjar and "" in _keyjar and entity_id: + if _keyjar and "" in _keyjar and id: # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), id) return _keyjar else: @@ -129,17 +127,24 @@ def __init__(self, keyjar: Optional[KeyJar] = None, context: Optional[ImpExp] = None, config: Optional[Union[Configuration, dict]] = None, - jwks_uri: Optional[str] = "", + # jwks_uri: Optional[str] = "", entity_id: Optional[str] = "", key_conf: Optional[dict] = None ): + if config is None: + config = {} + + self.entity_id = entity_id or config.get('entity_id') + self.client_id = config.get('client_id', entity_id) + Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, - httpc_params=httpc_params, config=config, entity_id=entity_id, + httpc_params=httpc_params, config=config, client_id=self.client_id, key_conf=key_conf) self.context = context or None +# Neither client nor Server class Collection(Unit): def __init__(self, @@ -153,8 +158,13 @@ def __init__(self, functions: Optional[dict] = None, metadata: Optional[dict] = None ): + if config is None: + config = {} + + self.entity_id = entity_id or config.get('entity_id') - Unit.__init__(self, upstream_get, keyjar, httpc, httpc_params, config, entity_id, key_conf) + Unit.__init__(self, upstream_get, keyjar, httpc, httpc_params, config, + issuer_id=self.entity_id, key_conf=key_conf) _args = { 'upstream_get': self.unit_get diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index cc31d6a9..4877e07e 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -50,14 +50,18 @@ def __init__( entity_id: Optional[str] = "", key_conf: Optional[dict] = None ): + self.entity_id = entity_id or conf.get('entity_id') + self.issuer = conf.get('issuer', self.entity_id) + Unit.__init__(self, config=conf, keyjar=keyjar, httpc=httpc, upstream_get=upstream_get, - httpc_params=httpc_params, entity_id=entity_id, key_conf=key_conf) + httpc_params=httpc_params, key_conf=key_conf, + issuer_id=self.issuer) self.upstream_get = upstream_get self.conf = conf self.endpoint_context = EndpointContext( conf=conf, - upstream_get=self.server_get, # points to me + upstream_get=self.unit_get, # points to me cwd=cwd, cookie_handler=cookie_handler ) @@ -66,14 +70,14 @@ def __init__( self.setup_authentication(self.endpoint_context) - self.endpoint = do_endpoints(conf, self.server_get) + self.endpoint = do_endpoints(conf, self.unit_get) _cap = get_provider_capabilities(conf, self.endpoint) self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.endpoint_context.do_add_on(endpoints=self.endpoint) self.endpoint_context.session_manager = create_session_manager( - self.server_get, + self.unit_get, self.endpoint_context.th_args, sub_func=self.endpoint_context._sub_func, conf=self.conf, @@ -85,14 +89,14 @@ def __init__( self.setup_client_authn_methods() for endpoint_name, _ in self.endpoint.items(): - self.endpoint[endpoint_name].upstream_get = self.server_get + self.endpoint[endpoint_name].upstream_get = self.unit_get _token_endp = self.endpoint.get("token") if _token_endp: _token_endp.allow_refresh = allow_refresh_token(self.endpoint_context) self.endpoint_context.claims_interface = init_service( - conf["claims_interface"], self.server_get + conf["claims_interface"], self.unit_get ) _id_token_handler = self.endpoint_context.session_manager.token_handler.handler.get( @@ -101,12 +105,6 @@ def __init__( if _id_token_handler: self.endpoint_context.provider_info.update(_id_token_handler.provider_info) - def server_get(self, what, *arg): - _func = getattr(self, "get_{}".format(what), None) - if _func: - return _func(*arg) - return None - def get_endpoints(self, *arg): return self.endpoint @@ -131,15 +129,15 @@ def get_entity(self, *args): def setup_authz(self): authz_spec = self.conf.get("authz") if authz_spec: - return init_service(authz_spec, self.server_get) + return init_service(authz_spec, self.unit_get) else: - return authz.Implicit(self.server_get) + return authz.Implicit(self.unit_get) def setup_authentication(self, target): _conf = self.conf.get("authentication") if _conf: target.authn_broker = populate_authn_broker( - _conf, self.server_get, target.template_handler + _conf, self.unit_get, target.template_handler ) else: target.authn_broker = {} @@ -169,5 +167,5 @@ def setup_login_hint_lookup(self): def setup_client_authn_methods(self): self.endpoint_context.client_authn_method = client_auth_setup( - self.server_get, self.conf.get("client_authn_methods") + self.unit_get, self.conf.get("client_authn_methods") ) diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index 2bfbc56a..85c99b1a 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -64,7 +64,7 @@ def full_path(local_file): conf = { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, - "keys": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"}, + "key_conf": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"}, "token_handler_args": { "jwks_def": { "private_path": "private/token_jwks.json", diff --git a/tests/test_server_13_user_authn.py b/tests/test_server_13_user_authn.py index b9279dce..c274855f 100644 --- a/tests/test_server_13_user_authn.py +++ b/tests/test_server_13_user_authn.py @@ -140,20 +140,18 @@ def test_userpassjinja2(self): "kwargs": {"filename": full_path("passwd.json")}, } template_handler = self.endpoint_context.template_handler - res = UserPassJinja2(db, template_handler, upstream_get=self.server.server_get) + res = UserPassJinja2(db, template_handler, upstream_get=self.server.unit_get) res() assert "page_header" in res.kwargs def test_basic_auth(self): basic_auth = base64.b64encode(b"diana:krall").decode() - ba = BasicAuthn(pwd={"diana": "krall"}, upstream_get=self.server.server_get) - _info, _time_stamp = ba.authenticated_as(client_id="", authorization=f"Basic {basic_auth}") - assert _info + ba = BasicAuthn(pwd={"diana": "krall"}, upstream_get=self.server.unit_get) + ba.authenticated_as(client_id="", authorization=f"Basic {basic_auth}") def test_no_auth(self): basic_auth = base64.b64encode( b"D\xfd\x8a\x85\xa6\xd1\x16\xe4\\6\x1e\x9ds~\xc3\t\x95\x99\x83\x91\x1f\xfb:iviviviv" ) - ba = SymKeyAuthn(symkey=b"0" * 32, ttl=600, upstream_get=self.server.server_get) - _info, _time_stamp = ba.authenticated_as(client_id="", authorization=basic_auth) - assert _info + ba = SymKeyAuthn(symkey=b"0" * 32, ttl=600, upstream_get=self.server.unit_get) + ba.authenticated_as(client_id="", authorization=basic_auth) diff --git a/tests/test_server_16_endpoint.py b/tests/test_server_16_endpoint.py index c00da5ae..25b863f6 100755 --- a/tests/test_server_16_endpoint.py +++ b/tests/test_server_16_endpoint.py @@ -78,7 +78,7 @@ def create_endpoint(self): server.endpoint_context.cdb["client_id"] = {} self.endpoint_context = server.endpoint_context - _endpoints = do_endpoints(conf, server.server_get) + _endpoints = do_endpoints(conf, server.unit_get) self.endpoint = _endpoints[""] def test_parse_urlencoded(self): diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index e8276d84..dbd6fba4 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -130,8 +130,8 @@ def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - server.endpoint = do_endpoints(CONF, server.server_get) - self.method = ClientSecretBasic(server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) + self.method = ClientSecretBasic(server.unit_get) def test_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) @@ -165,7 +165,7 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretPost(server.server_get) + self.method = ClientSecretPost(server.unit_get) def test_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} @@ -188,7 +188,7 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretJWT(server.server_get) + self.method = ClientSecretJWT(server.unit_get) def test_client_secret_jwt(self): client_keyjar = KeyJar() @@ -214,10 +214,10 @@ class TestPrivateKeyJWT: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - server.endpoint = do_endpoints(CONF, server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server self.endpoint_context = server.endpoint_context - self.method = PrivateKeyJWT(server.server_get) + self.method = PrivateKeyJWT(server.unit_get) def test_private_key_jwt(self): # Own dynamic keys @@ -307,10 +307,10 @@ class TestBearerHeader: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - server.endpoint = do_endpoints(CONF, server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerHeader(server.server_get) + self.method = BearerHeader(server.unit_get) def test_bearerheader(self): authorization_info = "Bearer 1234567890" @@ -330,10 +330,10 @@ class TestBearerBody: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - server.endpoint = do_endpoints(CONF, server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerBody(server.server_get) + self.method = BearerBody(server.unit_get) def test_bearer_body(self): request = {"access_token": "1234567890"} @@ -350,10 +350,10 @@ class TestJWSAuthnMethod: def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - server.endpoint = do_endpoints(CONF, server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server self.endpoint_context = server.endpoint_context - self.method = JWSAuthnMethod(server.server_get) + self.method = JWSAuthnMethod(server.unit_get) def test_jws_authn_method_wrong_key(self): client_keyjar = KeyJar() @@ -474,7 +474,7 @@ class TestVerify: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.server.endpoint = do_endpoints(CONF, self.server.server_get) + self.server.endpoint = do_endpoints(CONF, self.server.unit_get) self.endpoint_context = self.server.get_context() def test_verify_per_client(self): @@ -612,7 +612,7 @@ class TestVerify2: def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.server.endpoint = do_endpoints(CONF, self.server.server_get) + self.server.endpoint = do_endpoints(CONF, self.server.unit_get) self.endpoint_context = self.server.get_context() def test_verify_client_jws_authn_method(self): @@ -742,7 +742,7 @@ class Mock: conf["endpoint"]["registration"]["kwargs"]["client_authn_method"] = ["custom"] server = Server(conf=conf, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - server.endpoint = do_endpoints(CONF, server.server_get) + server.endpoint = do_endpoints(CONF, server.unit_get) request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index f2552f80..2bb761d8 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -93,7 +93,7 @@ def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretBasic(server.server_get) + self.method = ClientSecretBasic(server.unit_get) def test_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) @@ -127,7 +127,7 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretPost(server.server_get) + self.method = ClientSecretPost(server.unit_get) def test_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} @@ -150,7 +150,7 @@ def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.endpoint_context = server.endpoint_context - self.method = ClientSecretJWT(server.server_get) + self.method = ClientSecretJWT(server.unit_get) def test_client_secret_jwt(self): client_keyjar = KeyJar() @@ -178,7 +178,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = PrivateKeyJWT(server.server_get) + self.method = PrivateKeyJWT(server.unit_get) def test_private_key_jwt(self): # Own dynamic keys @@ -266,7 +266,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerHeader(server.server_get) + self.method = BearerHeader(server.unit_get) def test_bearerheader(self): authorization_info = "Bearer 1234567890" @@ -288,7 +288,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = BearerBody(server.server_get) + self.method = BearerBody(server.unit_get) def test_bearer_body(self): request = {"access_token": "1234567890"} @@ -307,7 +307,7 @@ def create_method(self): server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} self.server = server self.endpoint_context = server.endpoint_context - self.method = JWSAuthnMethod(server.server_get) + self.method = JWSAuthnMethod(server.unit_get) def test_jws_authn_method_wrong_key(self): client_keyjar = KeyJar() diff --git a/tests/test_server_20f_userinfo.py b/tests/test_server_20f_userinfo.py index 57007a95..161be87c 100644 --- a/tests/test_server_20f_userinfo.py +++ b/tests/test_server_20f_userinfo.py @@ -202,7 +202,7 @@ def create_endpoint_context(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } self.session_manager = self.endpoint_context.session_manager - self.claims_interface = ClaimsInterface(server.server_get) + self.claims_interface = ClaimsInterface(server.unit_get) self.user_id = "diana" self.server = server @@ -427,7 +427,7 @@ def create_endpoint_context(self, conf): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] } self.session_manager = self.endpoint_context.session_manager - self.claims_interface = ClaimsInterface(self.server.server_get) + self.claims_interface = ClaimsInterface(self.server.unit_get) self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier=""): From dea10b288caecdcfcf1813aed89a7fe996611b1d Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 1 Nov 2022 19:44:06 +0100 Subject: [PATCH 48/76] Deal with the fact that some responses may be signed and some clear text unsigned. --- src/idpyoidc/node.py | 68 +++++++++++++++++---------------- src/idpyoidc/server/endpoint.py | 10 ++++- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index e70fc1f1..9dcafdb7 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -4,11 +4,45 @@ from cryptojwt import KeyJar from cryptojwt.key_jar import init_key_jar + from idpyoidc.configure import Configuration from idpyoidc.impexp import ImpExp from idpyoidc.util import instantiate +def create_keyjar( + keyjar: Optional[KeyJar] = None, + conf: Optional[Union[dict, Configuration]] = None, + key_conf: Optional[dict] = None, + id: Optional[str] = "", +): + if keyjar is None: + if key_conf: + keys_args = {k: v for k, v in key_conf.items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + elif conf: + if "keys" in conf: + keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + elif "key_conf" in conf: + keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + else: + _keyjar = KeyJar() + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") + else: + _keyjar = None + + if _keyjar and "" in _keyjar and id: + # make sure I have the keys under my own name too (if I know it) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), id) + + return _keyjar + else: + return keyjar + + class Unit(ImpExp): name = '' @@ -32,7 +66,7 @@ def __init__(self, if keyjar or key_conf or config.get('key_conf') or config.get('jwks') or config.get('keys'): # Should be either one id = issuer_id or client_id - self.keyjar = self._keyjar(keyjar, conf=config, key_conf=key_conf, id=id) + self.keyjar = create_keyjar(keyjar, conf=config, key_conf=key_conf, id=id) if client_id: self.keyjar.add_symmetric('', client_id) else: @@ -74,38 +108,6 @@ def set_attribute(self, attr, val): def get_unit(self, *args): return self - def _keyjar(self, - keyjar: Optional[KeyJar] = None, - conf: Optional[Union[dict, Configuration]] = None, - key_conf: Optional[dict] = None, - id: Optional[str] = "", - ): - if keyjar is None: - if key_conf: - keys_args = {k: v for k, v in key_conf.items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - elif conf: - if "keys" in conf: - keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - elif "key_conf" in conf: - keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - else: - _keyjar = KeyJar() - if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") - else: - _keyjar = None - - if _keyjar and "" in _keyjar and id: - # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), id) - - return _keyjar - else: - return keyjar - def topmost_unit(unit): if hasattr(unit, 'upstream_get'): diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 5de31940..4a38a243 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -89,6 +89,7 @@ class Endpoint(object): request_placement = "query" response_format = "json" response_placement = "body" + response_content_type = "" client_authn_method = "" default_capabilities = None auth_method_attribute = "" @@ -382,7 +383,9 @@ def do_response( _response = "" content_type = kwargs.get("content_type") if content_type is None: - if self.response_format == "json": + if self.response_content_type: + content_type = self.response_content_type + elif self.response_format == "json": content_type = "application/json" elif self.response_format in ["jws", "jwe", "jose"]: content_type = "application/jose" @@ -402,7 +405,10 @@ def do_response( else: resp = json.dumps(_response) elif self.response_format in ["jws", "jwe", "jose"]: - content_type = "application/jose; charset=utf-8" + if self.response_content_type: + content_type = self.response_content_type + else: + content_type = "application/jose; charset=utf-8" resp = _response else: content_type = "application/x-www-form-urlencoded" From 32532bc074982f6ec49d484958a14e6d80c9b93c Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 8 Dec 2022 15:24:05 +0100 Subject: [PATCH 49/76] Rebased onto improved - partly done --- src/idpyoidc/client/client_auth.py | 55 +++- src/idpyoidc/client/entity.py | 38 ++- src/idpyoidc/client/oauth2/__init__.py | 2 +- src/idpyoidc/client/oauth2/access_token.py | 2 +- src/idpyoidc/client/oauth2/server_metadata.py | 3 +- src/idpyoidc/client/oauth2/utils.py | 2 - src/idpyoidc/client/oidc/__init__.py | 1 - src/idpyoidc/client/oidc/access_token.py | 7 +- src/idpyoidc/client/oidc/authorization.py | 18 +- .../client/oidc/provider_info_discovery.py | 3 +- src/idpyoidc/client/oidc/registration.py | 4 +- src/idpyoidc/client/oidc/webfinger.py | 2 +- src/idpyoidc/client/provider/github.py | 2 +- src/idpyoidc/client/provider/linkedin.py | 2 +- src/idpyoidc/client/rp_handler.py | 33 +-- src/idpyoidc/client/service.py | 51 ++-- src/idpyoidc/client/service_context.py | 23 +- .../client/work_environment/__init__.py | 23 +- .../client/work_environment/oauth2.py | 2 + src/idpyoidc/client/work_environment/oidc.py | 38 +-- src/idpyoidc/context.py | 15 +- src/idpyoidc/node.py | 7 +- src/idpyoidc/server/__init__.py | 2 +- src/idpyoidc/server/util.py | 30 --- .../server/work_environment/oauth2.py | 2 + src/idpyoidc/server/work_environment/oidc.py | 2 + tests/request123456.jwt | 2 +- tests/static/jwks.json | 2 +- tests/test_12_context.py | 2 +- tests/test_client_02_entity.py | 126 ++++++++- tests/test_client_02b_entity_metadata.py | 51 ++-- tests/test_client_04_service.py | 6 +- tests/test_client_06_client_authn.py | 24 +- tests/test_client_13_service_context.py | 254 ------------------ tests/test_client_20_oauth2.py | 15 +- tests/test_client_21_oidc_service.py | 43 +-- tests/test_client_22_oidc.py | 1 + tests/test_client_25_cc_oauth2_service.py | 5 +- tests/test_client_27_conversation.py | 8 +- tests/test_client_28_rp_handler_oidc.py | 3 +- tests/x_test_ciba_01_backchannel_auth.py | 4 +- 41 files changed, 429 insertions(+), 486 deletions(-) delete mode 100644 tests/test_client_13_service_context.py diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 1ab0b100..3d38dde4 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -1,6 +1,7 @@ """Implementation of a number of client authentication methods.""" import base64 import logging +from typing import Optional from cryptojwt.exception import MissingKey from cryptojwt.exception import UnsupportedAlgorithm @@ -17,6 +18,7 @@ from idpyoidc.util import rndstr from .util import sanitize from ..message import VREQUIRED +from ..util import instantiate # from idpyoidc.oidc.backchannel_authentication import ClientNotificationAuthn @@ -270,7 +272,7 @@ def find_token(request, token_type, service, **kwargs): try: return kwargs["access_token"] except KeyError: - # Get the latest acquired access token. + # Get the latest acquired token. _state = kwargs.get("state", kwargs.get("key")) _arg = service.upstream_get("context").cstate.get_set(_state, claim=[token_type]) return _arg.get("access_token") @@ -473,7 +475,7 @@ def _get_audience_and_algorithm(self, context, keyjar, **kwargs): algorithm = alg break - audience = context.provider_info["token_endpoint"] + audience = context.provider_info.get("token_endpoint") else: audience = context.provider_info["issuer"] @@ -613,15 +615,50 @@ def valid_service_context(service_context, when=0): return True -def client_auth_setup(auth_set=None): +def get_client_authn_class(name): + try: + return CLIENT_AUTHN_METHOD[name] + except KeyError: + return None + + +def get_client_authn_methods(): + return list(CLIENT_AUTHN_METHOD.keys()) + + +def method_to_item(methods): + if isinstance(methods, list): + return {k: get_client_authn_class(k) for k in methods if get_client_authn_class(k)} + elif isinstance(methods, dict): + return methods + elif not methods: + return {} + + +def single_authn_setup(name, spec): + if isinstance(spec, dict): # class and kwargs + if spec: + return instantiate(spec["class"], **spec["kwargs"]) + else: + cls = get_client_authn_class(name) + return cls() + else: + if spec is None: + cls = get_client_authn_class(name) + elif isinstance(spec, str): + cls = importer(spec) + else: + cls = spec + return cls() + + +def client_auth_setup(auth_set: Optional[dict] = None): if auth_set is None: auth_set = CLIENT_AUTHN_METHOD - else: - auth_set.update(CLIENT_AUTHN_METHOD) + res = {} - for name, cls in auth_set.items(): - if isinstance(cls, str): - cls = importer(cls) - res[name] = cls() + for name, spec in auth_set.items(): + res[name] = single_authn_setup(name, spec) + return res diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 7455ddf6..f80c5b3e 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -2,19 +2,21 @@ from typing import Optional from typing import Union +from cryptojwt import KeyBundle from cryptojwt import KeyJar +from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_jar import init_key_jar from idpyoidc.client.client_auth import client_auth_setup -from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD +from idpyoidc.client.client_auth import method_to_item from idpyoidc.client.configure import Configuration -from idpyoidc.client.configure import get_configuration from idpyoidc.client.defaults import DEFAULT_OAUTH2_SERVICES -from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.service import init_services from idpyoidc.client.service_context import ServiceContext from idpyoidc.context import OidcContext from idpyoidc.node import Unit +from idpyoidc.server.client_authn import client_auth_class logger = logging.getLogger(__name__) @@ -100,19 +102,13 @@ def __init__( if config is None: config = {} - self.entity_id = entity_id or config.get('entity_id') - self.client_id = config.get('client_id', entity_id) + _id = config.get('client_id') + self.client_id = self.entity_id = entity_id or config.get('entity_id', _id) Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, httpc_params=httpc_params, config=config, key_conf=key_conf, client_id=self.client_id) - if context: - self.context = context - else: - self.context = ServiceContext(config=config, jwks_uri=jwks_uri, - upstream_get=self.unit_get) - if services: _srvs = services elif config: @@ -125,7 +121,11 @@ def __init__( self._service = init_services(service_definitions=_srvs, upstream_get=self.unit_get) - self.keyjar = self._service_context.get_preference('keyjar') + if context: + self.context = context + else: + self.context = ServiceContext(config=config, jwks_uri=jwks_uri, keyjar=self.keyjar, + upstream_get=self.unit_get, client_type=client_type) self.setup_client_authn_methods(config) self.upstream_get = upstream_get @@ -164,20 +164,14 @@ def get_client_id(self): def setup_client_authn_methods(self, config): if config and "client_authn_methods" in config: - self.context.client_authn_method = client_auth_setup( - config.get("client_authn_methods") - ) + _methods = config.get("client_authn_methods") + self.context.client_authn_methods = client_auth_setup(method_to_item(_methods)) else: - _default_methods = set( - [s.default_authn_method for s in self._service.db.values() if - s.default_authn_method]) - _methods = {m: CLIENT_AUTHN_METHOD[m] for m in _default_methods if - m in CLIENT_AUTHN_METHOD} - self.context.client_authn_method = client_auth_setup(_methods) + self.context.client_authn_methods = {} def import_keys(self, keyspec): """ - The client needs it's own set of keys. It can either dynamically + The client needs its own set of keys. It can either dynamically create them or load them from local storage. This method can also fetch other entities keys provided the URL points to a JWKS. diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index dc17c031..d3e95f5d 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -71,7 +71,7 @@ def __init__( if not client_type: client_type = "oauth2" - if verify_ssl in False: + if verify_ssl is False: # just ignore verify_ssl until it goes away if httpc_params: httpc_params['verify'] = False diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 51b87ce4..fe259e86 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -2,9 +2,9 @@ import logging from typing import Optional +from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service -from idpyoidc.client.work_environment import get_client_authn_methods from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.time_util import time_sans_frac diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index bb4ba306..ed91531e 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -1,5 +1,6 @@ """The service that talks to the OAuth2 provider info discovery endpoint.""" import logging +from typing import Optional from cryptojwt.key_jar import KeyJar @@ -126,5 +127,5 @@ def _update_service_context(self, resp): elif "jwks" in resp: _keyjar.load_keys(_pcr_issuer, jwks=resp["jwks"]) - def update_service_context(self, resp, **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): return self._update_service_context(resp) diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index e16ce052..15d2c04c 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -25,7 +25,6 @@ def get_state_parameter(request_args, kwargs): def pick_redirect_uri( context, - entity, request_args: Optional[Union[Message, dict]] = None, response_type: Optional[str] = "", ): @@ -81,7 +80,6 @@ def pre_construct_pick_redirect_uri( **kwargs ): request_args["redirect_uri"] = pick_redirect_uri(service.upstream_get("context"), - entity=service.upstream_get("entity"), request_args=request_args) return request_args, {} diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index e7336f6a..ad3bf117 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -85,7 +85,6 @@ def __init__( services: Optional[dict] = None, httpc: Optional[Callable] = None, httpc_params: Optional[dict] = None, - context: Optional[OidcContext] = None, upstream_get: Optional[Callable] = None, key_conf: Optional[dict] = None, entity_id: Optional[str] = '', diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 87324567..b4081fbd 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -2,10 +2,10 @@ from typing import Optional from typing import Union +from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG -from idpyoidc.client.work_environment import get_client_authn_methods from idpyoidc.work_environment import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc @@ -21,6 +21,7 @@ class AccessToken(access_token.AccessToken): msg_type = oidc.AccessTokenRequest response_cls = oidc.AccessTokenResponse error_msg = oidc.ResponseMessage + default_authn_method = "client_secret_basic" _supports = { "token_endpoint_auth_method": get_client_authn_methods, @@ -89,7 +90,3 @@ def update_service_context(self, resp, key: Optional[str] ="", **kwargs): resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) _cstate.update(key, resp) - - def get_authn_method(self): - return self.upstream_get("context").get_preference("token_endpoint_auth_method", - self.default_authn_method) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index c3865544..59d043af 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -57,7 +57,7 @@ class Authorization(authorization.Authorization): def __init__(self, upstream_get, conf=None, request_args: Optional[dict] = None): authorization.Authorization.__init__(self, upstream_get, conf=conf) - self.default_request_args = {"scope": ["openid"]} + self.default_request_args.update({"scope": ["openid"]}) if request_args: self.default_request_args.update(request_args) self.pre_construct = [ @@ -101,7 +101,7 @@ def post_parse_response(self, response, **kwargs): if _idt: # If there is a verified ID Token then we have to do nonce # verification. - _req_nonce = self.superior_get("context").cstate.get_set( + _req_nonce = self.upstream_get("context").cstate.get_set( response["state"], claim=['nonce']).get('nonce') if _req_nonce: _id_token_nonce = _idt.get("nonce") @@ -255,17 +255,13 @@ def construct_request_parameter( if k in kwargs } - _req = make_openid_request(req, **_mor_args) + _req_jwt = make_openid_request(req, **_mor_args) # Should the request be encrypted - _req = request_object_encryption(_req, _context, - self.upstream_get('attribute', 'keyjar'), - **kwargs) - - if request_param == "request": - req["request"] = _req - else: # MUST be request_uri - req["request_uri"] = self.store_request_on_file(_req, **kwargs) + _req_jwte = request_object_encryption(_req_jwt, _context, + self.upstream_get('attribute', 'keyjar'), + **kwargs) + return _req_jwte def oidc_post_construct(self, req, **kwargs): """ diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 0acae2c7..bc9c1b7f 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.oauth2 import server_metadata @@ -52,7 +53,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): def __init__(self, upstream_get, conf=None): server_metadata.ServerMetadata.__init__(self, upstream_get, conf=conf) - def update_service_context(self, resp, **kwargs): + def update_service_context(self, resp, key: Optional[str] = '', **kwargs): _context = self.upstream_get("context") self._update_service_context(resp) _context.map_supported_to_preferred(resp) diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 3c0b019f..13e32644 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -1,5 +1,7 @@ import logging +from cryptojwt import KeyJar + from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service from idpyoidc.client.work_environment.transform import create_registration_request @@ -101,7 +103,7 @@ def gather_request_args(self, **kwargs): @param kwargs: @return: """ - _context = self.client_get("service_context") + _context = self.upstream_get("context") req_args = create_registration_request(_context.work_environment.prefer, _context.supports()) if "request_args" in self.conf: req_args.update(self.conf["request_args"]) diff --git a/src/idpyoidc/client/oidc/webfinger.py b/src/idpyoidc/client/oidc/webfinger.py index 84235c30..c97e8284 100644 --- a/src/idpyoidc/client/oidc/webfinger.py +++ b/src/idpyoidc/client/oidc/webfinger.py @@ -49,7 +49,7 @@ def update_service_context(self, resp, key="", **kwargs): for link in links: if link["rel"] == self.rel: _href = link["href"] - _context = self.client_get('service_context') + _context = self.upstream_get('service_context') _http_allowed = 'http_links' in _context.get("allow", default={}) if _href.startswith("http://") and not _http_allowed: diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index dc50c906..ef1e9fee 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -1,6 +1,6 @@ +from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo -from idpyoidc.client.work_environment import get_client_authn_methods from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_STRING diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 916d6f35..419ad189 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -1,6 +1,6 @@ from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo -from idpyoidc.client.work_environment import get_client_authn_methods +from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_JSON from idpyoidc.message import SINGLE_OPTIONAL_STRING diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index ab847249..e00599b2 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -32,6 +32,7 @@ from .oauth2 import Client from .oauth2 import dynamic_provider_info_discovery from .oauth2.utils import pick_redirect_uri +from ..message.oauth2 import ResponseMessage logger = logging.getLogger(__name__) @@ -127,7 +128,7 @@ def state2issuer(self, state): :return: An Issuer ID """ for _rp in self.issuer2rp.values(): - _iss = _rp.upstream_get("context").cstate.get_set( + _iss = _rp.get_context().cstate.get_set( state, claim=['iss']).get('iss') if _iss: return _iss @@ -155,7 +156,7 @@ def get_session_information(self, key, client=None): if not client: client = self.get_client_from_session_key(key) - return client.upstream_get("context").cstate.get(key) + return client.get_context().cstate.get(key) def init_client(self, issuer): """ @@ -198,7 +199,7 @@ def init_client(self, issuer): logger.error(message) raise - _context = client.upstream_get("context") + _context = client.get_context() if _context.iss_hash: self.hash2issuer[_context.iss_hash] = issuer # If non persistent @@ -206,7 +207,7 @@ def init_client(self, issuer): if not _keyjar: _keyjar = client.keyjar = KeyJar() _keyjar.load(self.keyjar.dump()) - # If persistent nothings has to be copied + # If persistent nothing has to be copied _context.base_url = self.base_url _context.jwks_uri = self.jwks_uri @@ -425,13 +426,13 @@ def init_authorization( else: raise ValueError("Missing state/session key") - _context = client.upstream_get("context") - _entity = client.upstream_get("entity") + _context = client.get_context() + #_entity = client.upstream_get("entity") _nonce = rndstr(24) _response_type = self._get_response_type(_context, req_args) request_args = { "redirect_uri": pick_redirect_uri( - _context, _entity, request_args=req_args, response_type=_response_type + _context, request_args=req_args, response_type=_response_type ), "scope": _context.work_environment.get_usage("scope"), "response_type": _response_type, @@ -522,7 +523,7 @@ def get_client_authn_method(client, endpoint): :return: The client authentication method """ if endpoint == "token_endpoint": - am = client.upstream_get("context").get_usage("token_endpoint_auth_method") + am = client.get_context().get_usage("token_endpoint_auth_method") if not am: return "" else: @@ -546,7 +547,7 @@ def get_tokens(self, state, client: Optional[Client] = None): if client is None: client = self.get_client_from_session_key(state) - _context = client.upstream_get("context") + _context = client.get_context() _claims = _context.cstate.get_set(state, claim=['code', 'redirect_uri']) req_args = { @@ -632,7 +633,7 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): client = self.get_client_from_session_key(state) if not access_token: - _arg = client.upstream_get("context").cstate.get_set(state, claim=["access_token"]) + _arg = client.get_context().cstate.get_set(state, claim=["access_token"]) access_token = _arg["access_token"] request_args = {"access_token": access_token} @@ -784,9 +785,9 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): know about. Once the consumer has redirected the user back to the callback URL there might be a number of services that the client should - use. Which one those are are defined by the client configuration. + use. Which one those are defined by the client configuration. - :param behaviour_args: For fine tuning + :param behaviour_args: For finetuning :param issuer: Who sent the response :param response: The Authorization response as a dictionary :returns: A dictionary with two claims: @@ -875,7 +876,7 @@ def has_active_authentication(self, state): client = self.get_client_from_session_key(state) # Look for an IdToken - _arg = client.upstream_get("context").cstate.get_set(state, + _arg = client.get_context().cstate.get_set(state, claim=["__verified_id_token"]) if _arg: @@ -899,7 +900,7 @@ def get_valid_access_token(self, state): now = utc_time_sans_frac() client = self.get_client_from_session_key(state) - _context = client.upstream_get("context") + _context = client.get_context() _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) if "access_token" in _args: access_token = _args["access_token"] @@ -973,7 +974,7 @@ def close( def clear_session(self, state): client = self.get_client_from_session_key(state) - client.upstream_get("context").cstate.remove_state(state) + client.get_context().cstate.remove_state(state) def backchannel_logout(client, request="", request_args=None): @@ -1033,7 +1034,7 @@ def load_registration_response(client, request_args=None): :param client: A :py:class:`idpyoidc.client.oidc.Client` instance """ - if not client.upstream_get("context").get_client_id(): + if not client.get_context().get_client_id(): try: response = client.do_request("registration", request_args=request_args) except KeyError: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 8e3053b2..025be666 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -16,6 +16,9 @@ from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message from idpyoidc.util import importer +from .client_auth import client_auth_setup +from .client_auth import method_to_item +from .client_auth import single_authn_setup from .configure import Configuration from .exception import ResponseError from .util import get_http_body @@ -78,6 +81,7 @@ def __init__( self.upstream_get = upstream_get self.default_request_args = {} + self.client_authn_methods = {} if conf: self.conf = conf @@ -85,10 +89,10 @@ def __init__( "msg_type", "response_cls", "error_msg", - "default_authn_method", "http_method", "request_body_type", "response_body_type", + "default_authn_method" ]: if param in conf: setattr(self, param, conf[param]) @@ -98,8 +102,20 @@ def __init__( self.default_request_args = _default_request_args del conf["request_args"] + _client_authn_methods = conf.get("client_authn_methods", None) + if _client_authn_methods: + self.client_authn_methods = client_auth_setup(method_to_item(_client_authn_methods)) + + if self.default_authn_method: + if self.default_authn_method not in self.client_authn_methods: + self.client_authn_methods[self.default_authn_method] = single_authn_setup( + self.default_authn_method, None) + else: self.conf = {} + if self.default_authn_method: + self.client_authn_methods[self.default_authn_method] = single_authn_setup( + self.default_authn_method, None) # pull in all the modifiers self.pre_construct = [] @@ -138,8 +154,8 @@ def gather_request_args(self, **kwargs): val = _use.get(prop) if not val: - #val = request_claim(_context, prop) - #if not val: + # val = request_claim(_context, prop) + # if not val: val = self.default_request_args.get(prop) if val: @@ -269,12 +285,15 @@ def init_authentication_method(self, request, authn_method, http_args=None, **kw if authn_method: LOGGER.debug("Client authn method: %s", authn_method) - _context = self.upstream_get("context") - try: - _func = _context.client_authn_method[authn_method] - except KeyError: # not one of the common - LOGGER.error(f"Unknown client authentication method: {authn_method}") - raise Unsupported(f"Unknown client authentication method: {authn_method}") + if self.client_authn_methods and authn_method in self.client_authn_methods: + _func = self.client_authn_methods[authn_method] + else: + _context = self.upstream_get("context") + try: + _func = _context.client_authn_methods[authn_method] + except KeyError: # not one of the common + LOGGER.error(f"Unknown client authentication method: {authn_method}") + raise Unsupported(f"Unknown client authentication method: {authn_method}") return _func.construct(request, self, http_args=http_args, **kwargs) @@ -505,7 +524,8 @@ def _do_jwt(self, info): enc_algs = _context.get_enc_alg_enc(self.service_name) args["allowed_enc_algs"] = enc_algs["alg"] args["allowed_enc_encs"] = enc_algs["enc"] - _jwt = JWT(key_jar=self.upstream_get('attribute','keyjar'), **args) + + _jwt = JWT(key_jar=_context.get_keyjar(), **args) _jwt.iss = _context.get_client_id() return _jwt.unpack(info) @@ -556,7 +576,7 @@ def parse_response( :param sformat: Which serialization that was used :param state: The state :param kwargs: Extra key word arguments - :return: The parsed and to some extend verified response + :return: The parsed and to some extent verified response """ if not sformat: @@ -570,11 +590,12 @@ def parse_response( self._do_jwt(info) sformat = "dict" except Exception: - _context = self.client_get("service_context") - resp = self.response_cls().from_jwe(info, keys=_context.keyjar) + _keyjar = self.upstream_get("attribute", 'keyjar') + resp = self.response_cls().from_jwe(info, keys=_keyjar) elif sformat == "jwe": - _context = self.client_get("service_context") - resp = self.response_cls().from_jwe(info, keys=_context.keyjar) + _keyjar = self.upstream_get("attribute", 'keyjar') + _client_id = self.upstream_get("attribute", 'client_id') + resp = self.response_cls().from_jwe(info, keys=_keyjar.get_issuer_keys(_client_id)) # If format is urlencoded 'info' may be a URL # in which case I have to get at the query/fragment part elif sformat == "urlencoded": diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index dc721d19..813caff0 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -27,6 +27,7 @@ from .work_environment.transform import preferred_to_registered from .work_environment.transform import supported_to_preferred from ..impexp import ImpExp +from ..node import Unit logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ } -class ServiceContext(ImpExp): +class ServiceContext(Unit): """ This class keeps information that a client needs to be able to talk to a server. Some of this information comes from configuration and some @@ -118,6 +119,7 @@ def __init__(self, cstate: Optional[Current] = None, upstream_get: Optional[Callable] = None, client_type: Optional[str] = 'oauth2', + keyjar: Optional[KeyJar] = None, **kwargs): ImpExp.__init__(self) config = get_configuration(config) @@ -272,8 +274,8 @@ def collect_usage(self): def supports(self): res = {} - if self.client_get: - services = self.client_get('services') + if self.upstream_get: + services = self.upstream_get('services') for service in services.values(): res.update(service.supports()) res.update(self.work_environment.supports()) @@ -294,9 +296,16 @@ def get_usage(self, claim, default: Optional[str] = None): def set_usage(self, claim, value): return self.work_environment.set_usage(claim, value) + def get_keyjar(self): + val = getattr(self, 'keyjar', None) + if not val: + return self.upstream_get('attribute', 'keyjar') + else: + return val + def _callback_per_service(self): _cb = {} - for service in self.client_get('services').values(): + for service in self.upstream_get('services').values(): _cbs = service._callback_path.keys() if _cbs: _cb[service.service_name] = _cbs @@ -313,8 +322,8 @@ def construct_uris(self, response_types: Optional[list] = None): _base_url = self.get("base_url") _callback_uris = self.get_preference('callback_uris', {}) - if self.client_get: - services = self.client_get('services') + if self.upstream_get: + services = self.upstream_get('services') for service in services.values(): _callback_uris.update(service.construct_uris(base_url=_base_url, hex=_hex, context=self, @@ -331,7 +340,7 @@ def prefer_or_support(self, claim): if claim in self.work_environment.prefer: return 'prefer' else: - for service in self.client_get('services').values(): + for service in self.upstream_get('services').values(): _res = service.prefer_or_support(claim) if _res: return _res diff --git a/src/idpyoidc/client/work_environment/__init__.py b/src/idpyoidc/client/work_environment/__init__.py index 7eba9d42..082effd5 100644 --- a/src/idpyoidc/client/work_environment/__init__.py +++ b/src/idpyoidc/client/work_environment/__init__.py @@ -2,11 +2,6 @@ from cryptojwt.jwk.hmac import SYMKey from idpyoidc import work_environment -from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD - - -def get_client_authn_methods(): - return list(CLIENT_AUTHN_METHOD.keys()) class WorkEnvironment(work_environment.WorkEnvironment): @@ -24,8 +19,22 @@ def get_id(self, configuration: dict): def add_extra_keys(self, keyjar, id): _secret = self.get_preference('client_secret') if _secret: - keyjar.add_symmetric(issuer_id=id, key=_secret) - keyjar.add_symmetric(issuer_id='', key=_secret) + _new = SYMKey(key=_secret) + try: + _id_keys = keyjar.get_issuer_keys(id) + except IssuerNotFound: + keyjar.add_symmetric(issuer_id=id, key=_secret) + else: + if _new not in _id_keys: + keyjar.add_symmetric(issuer_id=id, key=_secret) + + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + keyjar.add_symmetric(issuer_id='', key=_secret) + else: + if _new not in _own_keys: + keyjar.add_symmetric(issuer_id='', key=_secret) def get_jwks(self, keyjar): _jwks = None diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/work_environment/oauth2.py index 71fedde9..8212151b 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/work_environment/oauth2.py @@ -1,10 +1,12 @@ from typing import Optional from idpyoidc.client import work_environment +# from idpyoidc.client.client_auth import get_client_authn_methods class WorkEnvironment(work_environment.WorkEnvironment): _supports = { + # "client_authn_methods": get_client_authn_methods, "redirect_uris": None, "grant_types": ["authorization_code", "implicit", "refresh_token"], "response_types": ["code"], diff --git a/src/idpyoidc/client/work_environment/oidc.py b/src/idpyoidc/client/work_environment/oidc.py index f7996148..dae2b4b8 100644 --- a/src/idpyoidc/client/work_environment/oidc.py +++ b/src/idpyoidc/client/work_environment/oidc.py @@ -3,6 +3,8 @@ from idpyoidc import work_environment from idpyoidc.client import work_environment as client_work_environment +# from idpyoidc.client.client_auth import get_client_authn_methods + class WorkEnvironment(client_work_environment.WorkEnvironment): parameter = work_environment.WorkEnvironment.parameter.copy() @@ -11,32 +13,32 @@ class WorkEnvironment(client_work_environment.WorkEnvironment): }) _supports = { - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "id_token_signing_alg_values_supported": work_environment.get_signing_algs, - "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, - "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, "acr_values_supported": None, - "subject_types_supported": ["public", "pairwise", "ephemeral"], "application_type": "web", - "contacts": None, + "callback_uris": None, + # "client_authn_methods": get_client_authn_methods, + "client_id": None, "client_name": None, - "logo_uri": None, + "client_secret": None, "client_uri": None, - "policy_uri": None, - "tos_uri": None, + "contacts": None, + "default_max_age": 86400, + "encrypt_id_token_supported": None, + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "logo_uri": None, + "id_token_signing_alg_values_supported": work_environment.get_signing_algs, + "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, + "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, + "initiate_login_uri": None, "jwks": None, "jwks_uri": None, - "sector_identifier_uri": None, - "default_max_age": 86400, + "policy_uri": None, + "requests_dir": None, "require_auth_time": None, - "initiate_login_uri": None, - "client_id": None, - "client_secret": None, + "sector_identifier_uri": None, "scopes_supported": ["openid"], - # "verify_args": None, - "requests_dir": None, - "encrypt_id_token_supported": None, - "callback_uris": None + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "tos_uri": None, } def __init__(self, diff --git a/src/idpyoidc/context.py b/src/idpyoidc/context.py index de558b2b..9b17f2b5 100644 --- a/src/idpyoidc/context.py +++ b/src/idpyoidc/context.py @@ -23,8 +23,19 @@ def add_issuer(conf, issuer): class OidcContext(ImpExp): - parameter = {"issuer": None} + parameter = {"entity_id": None} def __init__(self, config=None, entity_id=""): ImpExp.__init__(self) - self.entity_id = entity_id or config.get('client_id') + if entity_id: + self.entity_id = entity_id + else: + if config: + val = '' + for alt in ['client_id', 'issuer', 'entity_id']: + val = config.get(alt) + if val: + break + self.entity_id = val + else: + self.entity_id = '' diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index 9dcafdb7..df1051d6 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -71,8 +71,11 @@ def __init__(self, self.keyjar.add_symmetric('', client_id) else: if client_id: - self.keyjar = KeyJar() - self.keyjar.add_symmetric('', client_id) + _key = config.get("client_secret") + if _key: + self.keyjar = KeyJar() + self.keyjar.add_symmetric(client_id, _key) + self.keyjar.add_symmetric('', _key) else: self.keyjar = None diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 4877e07e..45b36c3d 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -166,6 +166,6 @@ def setup_login_hint_lookup(self): self.endpoint_context.login_hint_lookup.userinfo = _userinfo def setup_client_authn_methods(self): - self.endpoint_context.client_authn_method = client_auth_setup( + self.endpoint_context.client_authn_methods = client_auth_setup( self.unit_get, self.conf.get("client_authn_methods") ) diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index eea2579d..ef481578 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -170,33 +170,3 @@ def execute(spec, **kwargs): else: return kwargs -# def sector_id_from_redirect_uris(uris): -# if not uris: -# return "" -# -# _parts = urlparse(uris[0]) -# hostname = _parts.netloc -# scheme = _parts.scheme -# for uri in uris[1:]: -# parsed = urlparse(uri) -# if scheme != parsed.scheme or hostname != parsed.netloc: -# raise ValueError( -# "All redirect_uris must have the same hostname in order to generate sector_id." -# ) -# -# return urlunsplit((scheme, hostname, "", "", "")) - - -# def get_logout_id(context, user_id, client_id): -# _item = NodeInfo() -# _item.user_id = user_id -# _item.client_id = client_id -# -# # Note that this session ID is not the session ID the session manager is using. -# # It must be possible to map from one to the other. -# logout_session_id = uuid.uuid4().hex -# # Store the map -# _mngr = context.session_manager -# _mngr.set([logout_session_id], _item) -# -# return logout_session_id diff --git a/src/idpyoidc/server/work_environment/oauth2.py b/src/idpyoidc/server/work_environment/oauth2.py index 42ab7579..783251a8 100644 --- a/src/idpyoidc/server/work_environment/oauth2.py +++ b/src/idpyoidc/server/work_environment/oauth2.py @@ -1,6 +1,7 @@ from typing import Optional from idpyoidc.server import work_environment +# from idpyoidc.server.client_authn import get_client_authn_methods class WorkEnvironment(work_environment.WorkEnvironment): @@ -11,6 +12,7 @@ class WorkEnvironment(work_environment.WorkEnvironment): # 'ui_locales_supported', 'op_policy_uri', 'op_tos_uri', 'revocation_endpoint', # 'introspection_endpoint' _supports = { + # "client_authn_methods": get_client_authn_methods, "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], "response_types_supported": ["code"], "response_modes_supported": ["code"], diff --git a/src/idpyoidc/server/work_environment/oidc.py b/src/idpyoidc/server/work_environment/oidc.py index c776da38..071b4f95 100644 --- a/src/idpyoidc/server/work_environment/oidc.py +++ b/src/idpyoidc/server/work_environment/oidc.py @@ -2,6 +2,7 @@ from idpyoidc import work_environment as WE from idpyoidc.server import work_environment +# from idpyoidc.server.client_authn import get_client_authn_methods class WorkEnvironment(work_environment.WorkEnvironment): @@ -12,6 +13,7 @@ class WorkEnvironment(work_environment.WorkEnvironment): "claim_types_supported": None, "claims_locales_supported": None, "claims_supported": None, + # "client_authn_methods": get_client_authn_methods, "contacts": None, "default_max_age": 86400, "display_values_supported": None, diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 7a9db987..5f373e53 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJmMFNXNzRtbzFKSG1NbFAzZUVIWWhKcXZQTm1fZmxwYjBOcTJ0SXYzUXM0IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjY1ODIzODQ0LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.iJggLp-EJvEP4ARgGwCFIhLlwikTLV8EEd7D2PX-yW6H9rm261_l-NkKTKmfV_Y2-QLT1X3K0eepI_A1qVAzLzohFSw0OcPPJDRs9IugLxeZ0Ktr9pb29XcCHOU83DD3onIXTfzgihqX_aqUfPt32teD5NTTMmMGuaA700rtJiXzrPXWQmJXDVlStgtjFh4fZI59G3yPUNQqUTm0w_HHsF8IuzIPHFq5FTlixTaX3iu90dm9icXTJtLYxw5uHL7Je2_GxWTmCE9WEOzSI3AaQz-jIsG1RVVBx5WBRngHkcFPITuXCXklOKq_iFbFCcRL-Gt7SsDHqV_zrAm72LaIvg \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJBVTdNa3Z0cnNJRUxqRTE1dEVUeGx6ck9GdVZPUVRSM2h0ZldLMlcyakN3IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjcwNDk2MzA0LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.AJjkd2WenuqnvpKI1rODXmXTK_CvWR7zJ8EVB3y7y_nTK8xajubBQQbXJql1r6r2yzxGC7wXOXQnp-4CNFV45pHyjawxGbA-p-Ko4sdTzebiJDOf-JGPdh0hzWff0oepU0zsL3vqg9L8V534Z4v6ugZDYw1EUZaht5xvRFAUEyxwG6rEf05DRQif01288Zbnc8i5oCLpevCreTlKlo7_jEcJVSKlnmuyTyGpDGENgjt2U3hNb7pFMKOw8J848vq4ukQvDVlD_7qBzt_-VDN_NWIFkeSp2-1e_AbZtsQdXC-gLo9xaTOoS5hG5Eh1-fdzLGdmdb0m4Tz6stlFF_AWbw \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_12_context.py b/tests/test_12_context.py index 6ee98240..2448a86a 100644 --- a/tests/test_12_context.py +++ b/tests/test_12_context.py @@ -9,7 +9,7 @@ def test_context_with_entity_id(self): c = OidcContext({}, entity_id=ENTITY_ID) mem = c.dump() c2 = OidcContext().load(mem) - assert c2.issuer == ENTITY_ID + assert c2.entity_id == ENTITY_ID def test_context_with_entity_id_and_keys(self): c = OidcContext({"entity_id": ENTITY_ID}) diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index 7ea1199c..c1fe9163 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -1,5 +1,6 @@ import pytest +from idpyoidc.client.client_auth import ClientAuthnMethod from idpyoidc.client.entity import Entity KEYDEFS = [ @@ -37,7 +38,7 @@ def test_get_service_unsupported(self): def test_get_client_id(self): assert self.entity.get_service_context().get_preference("client_id") == "Number5" - assert self.entity.client_get("client_id") == "Number5" + assert self.entity.get_attribute("client_id") == "Number5" def test_get_service_by_endpoint_name(self): _srv = self.entity.get_service("") @@ -48,3 +49,126 @@ def test_get_service_by_endpoint_name(self): def test_get_service_context(self): _context = self.entity.get_service_context() assert _context + + +RP_BASEURL = "https://example.com/rp" +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + + +def test_client_authn_default(): + config = { + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "keys": {"key_defs": KEYSPEC, "read_only": True}, + } + + entity = Entity(config=config, client_type='oidc') + + assert entity.get_context().client_authn_methods == {} + + +def test_client_authn_by_names(): + config = { + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + } + + entity = Entity(config=config, client_type='oidc') + + assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', + 'client_secret_post'} + + +class FooBar(ClientAuthnMethod): + def __init__(self, **kwargs): + self.kwargs = kwargs + + def modify_request(self, request, service, **kwargs): + request.update(self.kwargs) + + +def test_client_authn_full(): + config = { + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_authn_methods": { + 'client_secret_basic': {}, + 'client_secret_post': None, + 'home_brew': { + 'class': FooBar, + 'kwargs': {'one': 'bar'} + } + } + } + + entity = Entity(config=config, client_type='oidc') + + assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', + 'client_secret_post', + 'home_brew'} + + +def test_service_specific(): + config = { + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + } + + entity = Entity(config=config, client_type='oidc', + services={ + "xyz": { + "class": "idpyoidc.client.service.Service", + "kwargs": { + "client_authn_methods": ['private_key_jwt'] + } + } + }) + + # A specific does not change the general + assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', + 'client_secret_post'} + + assert set(entity.get_service('').client_authn_methods.keys()) == {'private_key_jwt'} + + +def test_service_specific2(): + config = { + "application_type": "web", + "contacts": ["ops@example.org"], + "redirect_uris": [f"{RP_BASEURL}/authz_cb"], + "keys": {"key_defs": KEYSPEC, "read_only": True}, + "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + } + + entity = Entity(config=config, client_type='oidc', + services={ + "xyz": { + "class": "idpyoidc.client.service.Service", + "kwargs": { + "client_authn_methods": { + 'home_brew': { + 'class': FooBar, + 'kwargs': {'one': 'bar'} + } + } + } + } + }) + + # A specific does not change the general + assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', + 'client_secret_post'} + + assert set(entity.get_service('').client_authn_methods.keys()) == {'home_brew'} diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 8c9bb5e7..2bd8412b 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -23,7 +23,8 @@ "userinfo_signing_alg_values_supported": ["ES256"], "post_logout_redirect_uris": ["https://rp.example.com/post"], "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True + "backchannel_logout_session_required": True, + "client_authn_methods": ['bearer_header'] }, "services": { @@ -64,8 +65,9 @@ def test_create_client(): client = Entity(config=CLIENT_CONFIG, client_type='oidc') - client.get_service_context().map_supported_to_preferred() - _pref = client.prefers() + _context = client.get_context() + _context.map_supported_to_preferred() + _pref = _context.prefers() assert set(_pref.keys()) == {'application_type', 'backchannel_logout_session_required', 'backchannel_logout_uri', @@ -110,27 +112,28 @@ def test_create_client(): rr = set(RegistrationRequest.c_param.keys()) # The ones that are not defined d = rr.difference(set(_conf_args)) - assert d == {'client_name', - 'client_uri', - 'default_acr_values', - 'frontchannel_logout_session_required', - 'frontchannel_logout_uri', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', - 'initiate_login_uri', - 'jwks', - 'jwks_uri', - 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uri', - 'request_object_encryption_alg', - 'request_object_encryption_enc', - 'request_uris', - 'require_auth_time', - 'sector_identifier_uri', - 'tos_uri', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc'} + assert d == { + 'client_name', + 'client_uri', + 'default_acr_values', + 'frontchannel_logout_session_required', + 'frontchannel_logout_uri', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'initiate_login_uri', + 'jwks', + 'jwks_uri', + 'logo_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'request_object_encryption_alg', + 'request_object_encryption_enc', + 'request_uris', + 'require_auth_time', + 'sector_identifier_uri', + 'tos_uri', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc'} def test_create_client_key_conf(): diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index d2778a7a..121e77b8 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -115,7 +115,7 @@ def test_parse_response_json(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() + _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_json() arg = self.service.parse_response(resp1) assert isinstance(arg, AuthorizationResponse) @@ -127,7 +127,7 @@ def test_parse_response_jwt(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() + _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -141,7 +141,7 @@ def test_parse_response_err(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute','keyjar').get_signing_key() + _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 321d0203..c1dadfe1 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -211,8 +211,8 @@ def test_construct(self, entity): assert http_args is None def test_construct_with_state(self, entity): - _auth_service = entity.upstream_get("service", "") - _cntx = _auth_service.upstream_get("service_context") + _auth_service = entity.get_service("accesstoken") + _cntx = _auth_service.upstream_get("context") _key = _cntx.cstate.create_key() _cntx.cstate.set(_key, {'iss': "Issuer"}) @@ -260,7 +260,7 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): def test_construct(self, entity): - _token_service = entity.upstream_get("service", "") + _token_service = entity.get_service("") request = _token_service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() @@ -277,7 +277,7 @@ def test_construct(self, entity): assert http_args is None def test_modify_1(self, entity): - token_service = entity.upstream_get("service", "") + token_service = entity.get_service("") request = token_service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() @@ -285,7 +285,7 @@ def test_modify_1(self, entity): assert "client_secret" in request def test_modify_2(self, entity): - _service = entity.upstream_get("service", "") + _service = entity.get_service("") request = _service.construct(request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) csp = ClientSecretPost() @@ -308,7 +308,7 @@ def test_construct(self, entity): key.add_kid() _context = token_service.upstream_get('context') - token_service.upstream_get('attribute', 'keyjar').add_kb("", kb_rsa) + _context.get_keyjar().add_kb("", kb_rsa) _context.provider_info = { "issuer": "https://example.com/", "token_endpoint": "https://example.com/token", @@ -404,11 +404,15 @@ def test_get_key_by_kid(self, entity): csj = ClientSecretJWT() request = AccessTokenRequest() - # get a kid - _keys = entity.get_attribute('keyjar').get_issuer_keys("") - kid = _keys[0].kid + # get a kid for a symmetric key + kid = '' + for _key in entity.get_attribute('keyjar').get_issuer_keys(""): + if _key.kty == 'oct': + kid = _key.kid + break + # token_service = entity.get_service("") - token_service = entity.upstream_get("service", "accesstoken") + token_service = entity.get_service("accesstoken") csj.construct(request, service=token_service, authn_endpoint="token_endpoint", kid=kid) assert "client_assertion" in request diff --git a/tests/test_client_13_service_context.py b/tests/test_client_13_service_context.py deleted file mode 100644 index 5b3ceef8..00000000 --- a/tests/test_client_13_service_context.py +++ /dev/null @@ -1,254 +0,0 @@ -import os -from urllib.parse import urlsplit - -import pytest -import responses -from cryptojwt.key_jar import build_keyjar - -from idpyoidc.client.entity import Entity -from idpyoidc.client.service_context import ServiceContext - -BASE_URL = "https://entity.example.org" - - -def test_client_info_init(): - config = { - "client_id": "client_id", - "issuer": "issuer", - "client_secret": "client_secret_wordplay", - "base_url": "https://example.com", - "requests_dir": "requests", - } - entity = Entity(entity_id=BASE_URL, config=config) - entity_copy = Entity().load(entity.dump()) - - srvcnx = entity_copy.get_context() - - for attr in config.keys(): - try: - val = getattr(srvcnx, attr) - except AttributeError: - val = srvcnx.get(attr) - - assert val == config[attr] - - -def test_set_and_get_client_secret(): - service_context = ServiceContext() - service_context.client_secret = "longenoughsupersecret" - assert service_context.client_secret == "longenoughsupersecret" - - -def test_set_and_get_client_id(): - ci = ServiceContext() - ci.client_id = "myself" - assert ci.client_id == "myself" - - -def test_client_filename(): - config = { - "client_id": "client_id", - "issuer": "issuer", - "client_secret": "longenoughsupersecret", - "base_url": "https://example.com", - "requests_dir": "requests", - } - entity = Entity(config=config) - fname = entity.get_context().filename_from_webname("https://example.com/rq12345") - assert fname == "rq12345" - - -def verify_alg_support(service_context, alg, usage, typ): - """ - Verifies that the algorithm to be used are supported by the other side. - This will look at provider information either statically configured or - obtained through dynamic provider info discovery. - - :param alg: The algorithm specification - :param usage: In which context the 'alg' will be used. - The following contexts are supported: - - userinfo - - id_token - - request_object - - token_endpoint_auth - :param typ: Type of algorithm - - signing_alg - - encryption_alg - - encryption_enc - :return: True or False - """ - - supported = service_context.provider_info["{}_{}_values_supported".format(usage, typ)] - - if alg in supported: - return True - else: - return False - - -class TestClientInfo(object): - @pytest.fixture(autouse=True) - def create_client_info_instance(self): - config = { - "client_id": "client_id", - "issuer": "issuer", - "client_secret": "longenoughsupersecret", - "base_url": "https://example.com", - "requests_dir": "requests", - } - self.entity = Entity(config=config) - self.service_context = self.entity.get_context() - - def test_registration_userinfo_sign_enc_algs(self): - self.service_context.behaviour = { - "application_type": "web", - "redirect_uris": [ - "https://client.example.org/callback", - "https://client.example.org/callback2", - ], - "token_endpoint_auth_method": "client_secret_basic", - "jwks_uri": "https://client.example.org/my_public_keys.jwks", - "userinfo_encrypted_response_alg": "RSA1_5", - "userinfo_encrypted_response_enc": "A128CBC-HS256", - } - - assert self.service_context.get_sign_alg("userinfo") is None - assert self.service_context.get_enc_alg_enc("userinfo") == { - "alg": "RSA1_5", - "enc": "A128CBC-HS256", - } - - def test_registration_request_object_sign_enc_algs(self): - self.service_context.behaviour = { - "application_type": "web", - "redirect_uris": [ - "https://client.example.org/callback", - "https://client.example.org/callback2", - ], - "token_endpoint_auth_method": "client_secret_basic", - "jwks_uri": "https://client.example.org/my_public_keys.jwks", - "userinfo_encrypted_response_alg": "RSA1_5", - "userinfo_encrypted_response_enc": "A128CBC-HS256", - "request_object_signing_alg": "RS384", - } - - res = self.service_context.get_enc_alg_enc("userinfo") - # 'sign':'RS256' is an added default - assert res == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} - res = self.service_context.get_sign_alg("request_object") - assert res == "RS384" - - def test_registration_id_token_sign_enc_algs(self): - self.service_context.behaviour = { - "application_type": "web", - "redirect_uris": [ - "https://client.example.org/callback", - "https://client.example.org/callback2", - ], - "token_endpoint_auth_method": "client_secret_basic", - "jwks_uri": "https://client.example.org/my_public_keys.jwks", - "userinfo_encrypted_response_alg": "RSA1_5", - "userinfo_encrypted_response_enc": "A128CBC-HS256", - "request_object_signing_alg": "RS384", - "id_token_encrypted_response_alg": "ECDH-ES", - "id_token_encrypted_response_enc": "A128GCM", - "id_token_signed_response_alg": "ES384", - } - - res = self.service_context.get_enc_alg_enc("userinfo") - # 'sign':'RS256' is an added default - assert res == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} - res = self.service_context.get_sign_alg("request_object") - assert res == "RS384" - res = self.service_context.get_enc_alg_enc("id_token") - assert res == {"alg": "ECDH-ES", "enc": "A128GCM"} - - def test_verify_alg_support(self): - self.service_context.provider_info = { - "version": "3.0", - "issuer": "https://server.example.com", - "authorization_endpoint": "https://server.example.com/connect/authorize", - "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], - "token_endpoint_auth_signing_alg_values_supported": ["RS256", "ES256"], - "userinfo_endpoint": "https://server.example.com/connect/userinfo", - "check_session_iframe": "https://server.example.com/connect/check_session", - "end_session_endpoint": "https://server.example.com/connect/end_session", - "jwks_uri": "https://server.example.com/jwks.json", - "registration_endpoint": "https://server.example.com/connect/register", - "scopes_supported": [ - "openid", - "profile", - "email", - "address", - "phone", - "offline_access", - ], - "response_types_supported": ["code", "code id_token", "id_token", "token id_token"], - "acr_values_supported": [ - "urn:mace:incommon:iap:silver", - "urn:mace:incommon:iap:bronze", - ], - "subject_types_supported": ["public", "pairwise"], - "userinfo_signing_alg_values_supported": ["RS256", "ES256", "HS256"], - "userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"], - "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], - "id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"], - "id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"], - "id_token_encryption_enc_values_supported": ["A128CBC+HS256", "A128GCM"], - "request_object_signing_alg_values_supported": ["none", "RS256", "ES256"], - "display_values_supported": ["page", "popup"], - "claim_types_supported": ["normal", "distributed"], - "claims_supported": [ - "sub", - "iss", - "auth_time", - "acr", - "name", - "given_name", - "family_name", - "nickname", - "profile", - "picture", - "website", - "email", - "email_verified", - "locale", - "zoneinfo", - "http://example.info/claims/groups", - ], - "claims_parameter_supported": True, - "service_documentation": "http://server.example.com/connect/service_documentation.html", - "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"], - } - - assert verify_alg_support(self.service_context, "RS256", "id_token", "signing_alg") - assert verify_alg_support(self.service_context, "RS512", "id_token", "signing_alg") is False - - assert verify_alg_support(self.service_context, "RSA1_5", "userinfo", "encryption_alg") - - # token_endpoint_auth_signing_alg_values_supported - assert verify_alg_support( - self.service_context, "ES256", "token_endpoint_auth", "signing_alg" - ) - - def test_verify_requests_uri(self): - self.service_context.provider_info = {"issuer": "https://example.com/"} - url_list = self.service_context.generate_redirect_uris("/leading") - sp = urlsplit(url_list[0]) - p = sp.path.split("/") - assert p[0] == "" - assert p[1] == "leading" - assert len(p) == 3 - - # different for different OPs - self.service_context.provider_info = {"issuer": "https://op.example.org/"} - url_list = self.service_context.generate_redirect_uris("/leading") - sp = urlsplit(url_list[0]) - np = sp.path.split("/") - assert np[0] == "" - assert np[1] == "leading" - assert len(np) == 3 - - assert np[2] != p[2] - diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index 2aedc61a..b9d67fb0 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -65,7 +65,7 @@ def test_construct_authorization_request(self): "response_type": ["code"], } - self.client.get_context.cstate.set("ABCDE", {"iss": 'issuer'}) + self.client.get_context().cstate.set("ABCDE", {"iss": 'issuer'}) msg = self.client.get_service("authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg["client_id"] == "client_1" @@ -183,7 +183,8 @@ def create_client(self): "read_only": False, }, "clients": { - "client_1": { + "service_1": { + "client_id": "client_1", "client_secret": "abcdefghijklmnop", "redirect_uris": ["https://example.com/cli/authz_cb"], } @@ -191,17 +192,11 @@ def create_client(self): } rp_conf = RPHConfiguration(conf) rp_handler = RPHandler(base_url=BASE_URL, config=rp_conf) - self.client = rp_handler.init_client(issuer="client_1") + self.client = rp_handler.init_client(issuer="service_1") assert self.client def test_keyjar(self): - req_args = { - "state": "ABCDE", - "redirect_uri": "https://example.com/auth_cb", - "response_type": ["code"], - } - _keyjar = self.client.get_attribute('keyjar') - assert len(_keyjar) == 1 # one issuer + assert len(_keyjar) == 2 # one issuer assert len(_keyjar[""]) == 2 assert len(_keyjar.get("sig")) == 2 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index e9ae9c1e..5e36f797 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1,4 +1,3 @@ -import json import os from cryptojwt.exception import UnsupportedAlgorithm @@ -399,9 +398,12 @@ def create_request(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], + 'client_authn_methods': ['client_secret_basic'] } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) - entity.get_context().issuer = "https://example.com" + _context = entity.get_context() + _context.issuer = "https://example.com" + _context.provider_info = {'token_endpoint': f'{_context.issuer}/token'} self.service = entity.get_service("accesstoken") # add some history @@ -455,16 +457,10 @@ def test_request_init(self): assert set(_info.keys()) == {"body", "url", "headers", "method", "request"} assert _info["url"] == "https://example.com/authorize" msg = AccessTokenRequest().from_urlencoded(self.service.get_urlinfo(_info["body"])) - assert msg.to_dict() == { - "client_id": "client_id", - "code": "access_code", - "grant_type": "authorization_code", - "state": "state", - "redirect_uri": "https://example.com/cli/authz_cb", - } + assert set(msg.keys()) == {'redirect_uri', 'grant_type', 'state', 'code', 'client_id'} def test_id_token_nonce_match(self): - _cstate = self.service.get_context().cstate + _cstate = self.service.upstream_get("context").cstate _cstate.bind_key("nonce", "state") resp = AccessTokenResponse() resp[verified_claim_name("id_token")] = {"nonce": "nonce"} @@ -729,7 +725,7 @@ def test_post_parse(self): "registration_endpoint": "{}/registration".format(OP_BASEURL), "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } - _context = self.service.get_context() + _context = self.service.upstream_get("context") assert _context.work_environment.use == {} resp = self.service.post_parse_response(provider_info_response) @@ -917,10 +913,10 @@ def test_config_with_required_request_uri(): client_type='oidc') entity.get_context().issuer = "https://example.com" - pi_service = entity.client_get("service", "provider_info") + pi_service = entity.get_service("provider_info") pi_service.match_preferences({"require_request_uri_registration": True}) - reg_service = entity.client_get("service", "registration") + reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == {"application_type", "response_types", "jwks", @@ -956,11 +952,11 @@ def test_config_logout_uri(): _context = entity.get_context() _context.issuer = "https://example.com" - pi_service = entity.client_get("service", "provider_info") + pi_service = entity.get_service("provider_info") _pi = {"require_request_uri_registration": True, "backchannel_logout_supported": True} pi_service.match_preferences(_pi) - reg_service = entity.client_get("service", "registration") + reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == {'application_type', @@ -1002,7 +998,7 @@ def create_request(self): "userinfo_encrypted_response_enc": "A256GCM", } - _cstate = self.service.get_context().cstate + _cstate = self.service.upstream_get("context").cstate # Add history auth_response = AuthorizationResponse(code="access_code") _cstate.update("abcde", auth_response) @@ -1104,7 +1100,7 @@ def test_unpack_encrypted_response(self): # Add encryption key _kj = build_keyjar([{"type": "RSA", "use": ["enc"]}], issuer_id="") # Own key jar gets the private key - self.service.upstream_get("attribute",'keyjar').import_jwks( + self.service.upstream_get("attribute", 'keyjar').import_jwks( _kj.export_jwks(private=True), issuer_id="client_id" ) # opponent gets the public key @@ -1119,7 +1115,7 @@ def test_unpack_encrypted_response(self): ) enc_resp = resp.to_jwe(enckey, **algspec) - _resp = self.service.parse_response(enc_resp, state="abcde", sformat="jwt") + _resp = self.service.parse_response(enc_resp, state="abcde", sformat="jwe") assert _resp @@ -1215,7 +1211,7 @@ def test_authz_service_conf(): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "preference": {"response_types": ["code"]}, + "response_types": ["code"], } services = { @@ -1242,7 +1238,14 @@ def test_authz_service_conf(): service = entity.get_service("authorization") req = service.construct() - assert "claims" in req + assert set(req.keys()) == {'claims', + 'client_id', + 'nonce', + 'redirect_uri', + 'response_type', + 'scope', + 'state'} + assert set(req["claims"].keys()) == {"id_token"} diff --git a/tests/test_client_22_oidc.py b/tests/test_client_22_oidc.py index eca2f7e6..9bbdd65e 100755 --- a/tests/test_client_22_oidc.py +++ b/tests/test_client_22_oidc.py @@ -50,6 +50,7 @@ def create_client(self): "redirect_uris": ["https://example.com/cli/authz_cb"], "client_id": "client_1", "client_secret": "abcdefghijklmnop", + 'client_authn_methods': ['bearer_header'] } self.client = RP(config=conf) diff --git a/tests/test_client_25_cc_oauth2_service.py b/tests/test_client_25_cc_oauth2_service.py index c7c6ae62..eb130b13 100644 --- a/tests/test_client_25_cc_oauth2_service.py +++ b/tests/test_client_25_cc_oauth2_service.py @@ -15,7 +15,8 @@ def create_service(self): client_config = { "client_id": "client_id", "client_secret": "another password", - "base_url": BASE_URL + "base_url": BASE_URL, + "client_authn_methods": ['client_secret_basic', 'bearer_header'] } services = { "token": { @@ -78,7 +79,7 @@ def test_refresh_token_get_request(self): } ) _srv = self.entity.get_service("refresh_token") - _info = _srv.get_request_parameters(state='') + _info = _srv.get_request_parameters(state='cc') assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" assert _info["body"] == "grant_type=refresh_token" diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 4e960ba5..06c25769 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -150,6 +150,7 @@ def test_conversation(): "backchannel_logout_uri": "https://rp.example.com/back", "backchannel_logout_session_required": True, 'allow': {'missing_kid': True}, + "client_authn_methods": ['bearer_header'], "services": SERVICES } @@ -523,7 +524,12 @@ def test_conversation(): assert info["url"] == "https://example.org/op/token" _qp = parse_qs(info["body"]) - assert set(_qp.keys()) == {'state', 'code', 'client_id', 'grant_type', 'redirect_uri'} + # since the default is private_key_jwt !!! + assert set(_qp.keys()) == {'client_id', + 'code', + 'grant_type', + 'redirect_uri', + 'state'} assert info["headers"]["Content-Type"] == "application/x-www-form-urlencoded" # create the IdToken diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index bec7c541..87a6605d 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -142,6 +142,7 @@ "scopes_supported": ["user", "public_repo"], "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, + 'encrypt_request_object': False }, "provider_info": { "authorization_endpoint": "https://github.com/login/oauth/authorize", @@ -155,7 +156,7 @@ "access_token": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, "userinfo": { "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": {"conf": {"default_authn_method": ""}}, + "kwargs": {"default_authn_method": ""}, }, "refresh_access_token": { "class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken" diff --git a/tests/x_test_ciba_01_backchannel_auth.py b/tests/x_test_ciba_01_backchannel_auth.py index d55d77e2..d96fc859 100644 --- a/tests/x_test_ciba_01_backchannel_auth.py +++ b/tests/x_test_ciba_01_backchannel_auth.py @@ -540,7 +540,9 @@ def _create_ciba_client(self): }, }, "client_authn_methods": { - "client_notification_authn": "idpyoidc.client.oidc.backchannel_authentication.ClientNotificationAuthn" + "client_notification_authn": { + 'class': "idpyoidc.client.oidc.backchannel_authentication.ClientNotificationAuthn" + } }, } From 64f482a6ac08c8a7aac9733f1b2aaf542704c322 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 11 Dec 2022 08:44:50 +0100 Subject: [PATCH 50/76] Rebased onto improved - tests working --- src/idpyoidc/client/oauth2/token_exchange.py | 8 +- src/idpyoidc/client/oidc/__init__.py | 1 + src/idpyoidc/client/service_context.py | 2 +- src/idpyoidc/server/__init__.py | 12 +- src/idpyoidc/server/endpoint.py | 5 +- src/idpyoidc/server/endpoint_context.py | 34 ++-- src/idpyoidc/server/oidc/provider_config.py | 4 +- src/idpyoidc/server/oidc/registration.py | 64 +++++-- src/idpyoidc/server/session/grant.py | 6 +- .../server/work_environment/oauth2.py | 9 + src/idpyoidc/server/work_environment/oidc.py | 10 +- tests/test_client_28_rp_handler_oidc.py | 173 +++++++----------- tests/test_client_41_rp_handler_persistent.py | 4 +- tests/test_client_55_token_exchange.py | 7 +- tests/test_server_16_endpoint_context.py | 5 +- tests/test_server_17_client_authn.py | 36 ---- ...st_server_23_oidc_registration_endpoint.py | 4 +- tests/test_server_50_persistence.py | 23 ++- 18 files changed, 190 insertions(+), 217 deletions(-) diff --git a/src/idpyoidc/client/oauth2/token_exchange.py b/src/idpyoidc/client/oauth2/token_exchange.py index 0a31d743..9bed32cd 100644 --- a/src/idpyoidc/client/oauth2/token_exchange.py +++ b/src/idpyoidc/client/oauth2/token_exchange.py @@ -27,14 +27,14 @@ class TokenExchange(Service): request_body_type = "urlencoded" response_body_type = "json" - def __init__(self, client_get, conf=None): - Service.__init__(self, client_get, conf=conf) + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.client_get("service_context").cstate.update(key, resp) + self.upstream_get("service_context").cstate.update(key, resp) def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): """ @@ -54,7 +54,7 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): parameters = {'access_token', 'scope'} - _current = self.client_get("service_context").cstate + _current = self.upstream_get("service_context").cstate _args = _current.get_set(_key, claim=parameters) diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index ad3bf117..05bcc894 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -107,6 +107,7 @@ def __init__( entity_id=entity_id, verify_ssl=verify_ssl, jwks_uri=jwks_uri, + client_type='oidc', **kwargs ) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 813caff0..2290fb63 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -118,7 +118,7 @@ def __init__(self, config: Optional[Union[dict, Configuration]] = None, cstate: Optional[Current] = None, upstream_get: Optional[Callable] = None, - client_type: Optional[str] = 'oauth2', + client_type: Optional[str] = 'oidc', keyjar: Optional[KeyJar] = None, **kwargs): ImpExp.__init__(self) diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 45b36c3d..6218ec5a 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -59,21 +59,23 @@ def __init__( self.upstream_get = upstream_get self.conf = conf + + self.endpoint = do_endpoints(conf, self.unit_get) + self.endpoint_context = EndpointContext( conf=conf, upstream_get=self.unit_get, # points to me cwd=cwd, - cookie_handler=cookie_handler + cookie_handler=cookie_handler, + keyjar=keyjar ) self.endpoint_context.authz = self.setup_authz() self.setup_authentication(self.endpoint_context) - self.endpoint = do_endpoints(conf, self.unit_get) - _cap = get_provider_capabilities(conf, self.endpoint) - - self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) + # _cap = get_provider_capabilities(conf, self.endpoint) + # self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.endpoint_context.do_add_on(endpoints=self.endpoint) self.endpoint_context.session_manager = create_session_manager( diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 4a38a243..a3d470ed 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -14,7 +14,6 @@ from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.server.client_authn import verify_client -from idpyoidc.server.construct import construct_provider_info from idpyoidc.server.exception import UnAuthorizedClient from idpyoidc.server.util import OAUTH2_NOCACHE_HEADERS from idpyoidc.util import sanitize @@ -166,7 +165,7 @@ def verify_request(self, request, keyjar, client_id, verify_args, lap=0): except IssuerNotFound as err: if lap: return self.error_cls(error=err) - client_id =self.find_client_keys(err.args[0]) + client_id = self.find_client_keys(err.args[0]) if not client_id: return self.error_cls(error=err) else: @@ -260,8 +259,6 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No kwargs["get_client_id_from_token"] = getattr(self, "get_client_id_from_token", None) authn_info = verify_client( - context=self.upstream_get("context"), - keyjar=self.upstream_get('attribute', 'keyjar'), request=request, http_info=http_info, **kwargs diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 67f21c30..9bdba461 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -8,7 +8,6 @@ from cryptojwt import KeyJar from jinja2 import Environment from jinja2 import FileSystemLoader - from requests import request from idpyoidc.context import OidcContext @@ -116,14 +115,15 @@ class EndpointContext(OidcContext): } def __init__( - self, - conf: Union[dict, OPConfiguration], - upstream_get: Callable, - cwd: Optional[str] = "", - cookie_handler: Optional[Any] = None, - httpc: Optional[Any] = None, - server_type: Optional[str] = '', - entity_id: Optional[str] = "" + self, + conf: Union[dict, OPConfiguration], + upstream_get: Callable, + cwd: Optional[str] = "", + cookie_handler: Optional[Any] = None, + httpc: Optional[Any] = None, + server_type: Optional[str] = '', + entity_id: Optional[str] = "", + keyjar: Optional[KeyJar] = None ): _id = entity_id or conf.get("issuer", "") OidcContext.__init__(self, conf, entity_id=_id) @@ -252,11 +252,11 @@ def __init__( self.claims_interface = init_service(_interface, self.upstream_get) if isinstance(conf, OPConfiguration): - self.keyjar = self.work_environment.load_conf(conf.conf, supports=self.supports(), - keyjar=keyjar) - else: # OidcConfig - self.keyjar = self.work_environment.load_conf(conf, supports=self.supports(), - keyjar=keyjar) + conf = conf.conf + _supports = self.supports() + self.keyjar = self.work_environment.load_conf(conf, supports=_supports, keyjar=keyjar) + self.provider_info = self.work_environment.provider_info(_supports) + self.provider_info['issuer'] = self.issuer def new_cookie(self, name: str, max_age: Optional[int] = 0, **kwargs): cookie_cont = self.cookie_handler.make_cookie_content( @@ -360,8 +360,8 @@ def do_login_hint_lookup(self): def supports(self): res = {} - if self.server_get: - for endpoint in self.server_get('endpoints').values(): + if self.upstream_get: + for endpoint in self.upstream_get('endpoints').values(): res.update(endpoint.supports()) res.update(self.work_environment.supports()) return res @@ -371,7 +371,7 @@ def set_provider_info(self): supported = self.supports() _info = {'issuer': self.issuer, 'version': "3.0"} - for endp in self.server_get('endpoints').values(): + for endp in self.upstream_get('endpoints').values(): if endp.endpoint_name: _info[endp.endpoint_name] = endp.full_path diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 361d5195..2ac5e53e 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -28,9 +28,9 @@ def add_endpoints(self, request, client_id, context, **kwargs): ]: endp_instance = self.upstream_get("endpoint", endpoint) if endp_instance: - info[endp_instance.endpoint_name] = endp_instance.full_path + request[endp_instance.endpoint_name] = endp_instance.full_path - return info + return request def process_request(self, request=None, **kwargs): return {"response_args": self.upstream_get("context").provider_info} diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 5c55fb01..a916ada5 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -13,14 +13,12 @@ from idpyoidc.client.oidc import PREFERENCE2PROVIDER # from idpyoidc.defaults import PREFERENCE2SUPPORTED from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED - from idpyoidc.exception import MessageException from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import ClientRegistrationErrorResponse from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse from idpyoidc.server.endpoint import Endpoint -from idpyoidc.server.exception import CapabilitiesMisMatch from idpyoidc.server.exception import InvalidRedirectURIError from idpyoidc.server.exception import InvalidSectorIdentifier from idpyoidc.time_util import utc_time_sans_frac @@ -140,22 +138,52 @@ def __init__(self, *args, **kwargs): _seed = kwargs.get("seed") or rndstr(32) self.seed = as_bytes(_seed) - def match_client_request(self, request: dict) -> list: - err = [] + def match_claim(self, claim, val): + _context = self.upstream_get("context") + + # Use my defaults + _my_key = REGISTER2PREFERRED.get(claim, claim) + try: + _val = _context.provider_info[_my_key] + except KeyError: + return val + + try: + _claim_spec = RegistrationResponse.c_param[claim] + except KeyError: # something I don't know anything about + return None + + if _val: + if isinstance(_claim_spec[0], list): + if isinstance(val, str): + if val in _val: + return val + else: + return None + else: + return list(set(_val).intersection(set(val))) + else: + if val == _val: + return val + else: + return None + else: + return None + + def filter_client_request(self, request: dict) -> dict: + _args = {} _provider_info = self.upstream_get("context").provider_info for key, val in request.items(): if key not in REGISTER2PREFERRED: + _args[key] = val continue - _pi_key = REGISTER2PREFERRED.get(key, key) - if isinstance(val, str): - if val not in _provider_info[_pi_key]: - logger.error(f"CapabilitiesMisMatch: {key}") - err.append(key) + + _val = self.match_claim(key, val) + if _val: + _args[key] = _val else: - if not set(val).issubset(set(_provider_info[_pi_key])): - logger.error(f"CapabilitiesMisMatch: {key}") - err.append(key) - return err + logger.error(f"Capabilities mismatch: {key}={val} not supported") + return _args def do_client_registration(self, request, client_id, ignore=None): if ignore is None: @@ -382,14 +410,10 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): return ResponseMessage(error=_error, error_description="%s" % err) request.rm_blanks() - faulty_claims = self.match_client_request(request) - if faulty_claims: - return ResponseMessage( - error="invalid_request", - error_description=f"Don't support proposed {faulty_claims}" - ) - _context = self.upstream_get("context") + + request = self.filter_client_request(request) + if new_id: if self.kwargs.get("client_id_generator"): cid_generator = importer(self.kwargs["client_id_generator"]["class"]) diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index a191be45..3d9060ef 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -181,7 +181,7 @@ def add_acr_value(self, claims_release_point): def payload_arguments( self, session_id: str, - context, + context: 'EndpointContext', item: SessionToken, claims_release_point: str, scope: Optional[dict] = None, @@ -241,7 +241,7 @@ def payload_arguments( if context.session_manager.node_type[0] == "user": user_id, _, _ = context.session_manager.decrypt_branch_id(session_id) user_info = context.claims_interface.get_user_claims(user_id, - _claims_restriction) + _claims_restriction) payload.update(user_info) # Should I add the acr value @@ -255,7 +255,7 @@ def payload_arguments( def mint_token( self, session_id: str, - context: object, + context: object, token_class: str, token_handler: TokenHandler = None, based_on: Optional[SessionToken] = None, diff --git a/src/idpyoidc/server/work_environment/oauth2.py b/src/idpyoidc/server/work_environment/oauth2.py index 783251a8..7473329f 100644 --- a/src/idpyoidc/server/work_environment/oauth2.py +++ b/src/idpyoidc/server/work_environment/oauth2.py @@ -1,5 +1,6 @@ from typing import Optional +from idpyoidc.message.oauth2 import ASConfigurationResponse from idpyoidc.server import work_environment # from idpyoidc.server.client_authn import get_client_authn_methods @@ -34,3 +35,11 @@ def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) + + def provider_info(self, supports): + _info = {} + for key in ASConfigurationResponse.c_param.keys(): + _val = self.get_preference(key, supports.get(key, None)) + if _val: + _info[key] = _val + return _info diff --git a/src/idpyoidc/server/work_environment/oidc.py b/src/idpyoidc/server/work_environment/oidc.py index 071b4f95..a5503481 100644 --- a/src/idpyoidc/server/work_environment/oidc.py +++ b/src/idpyoidc/server/work_environment/oidc.py @@ -1,8 +1,8 @@ from typing import Optional from idpyoidc import work_environment as WE +from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.server import work_environment -# from idpyoidc.server.client_authn import get_client_authn_methods class WorkEnvironment(work_environment.WorkEnvironment): @@ -59,3 +59,11 @@ def verify_rules(self): if not self.get_preference('encrypt_id_token_supported'): self.set_preference('id_token_encryption_alg_values_supported', []) self.set_preference('id_token_encryption_enc_values_supported', []) + + def provider_info(self, supports): + _info = {} + for key in ProviderConfigurationResponse.c_param.keys(): + _val = self.get_preference(key, supports.get(key, None)) + if _val is not None: + _info[key] = _val + return _info diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index 87a6605d..d15ce18a 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -17,6 +17,8 @@ from idpyoidc.message.oidc import Link from idpyoidc.message.oidc import OpenIDSchema from idpyoidc.message.oidc import ProviderConfigurationResponse +from idpyoidc.message.oidc import RegistrationResponse +from idpyoidc.util import rndstr BASE_URL = "https://example.com/rp" @@ -880,20 +882,22 @@ class TestRPHandlerWithMockOP(object): @pytest.fixture(autouse=True) def rphandler_setup(self): self.issuer = "https://github.com/login/oauth/authorize" - self.mock_op = MockOP(issuer=self.issuer) - self.rph = RPHandler( - BASE_URL, client_configs=CLIENT_CONFIG, httpc=self.mock_op, keyjar=CLI_KEY - ) + # self.mock_op = MockOP(issuer=self.issuer) + self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY) def test_finalize(self): auth_query = self.rph.begin(issuer_id="github") # The authorization query is sent and after successful authentication client = self.rph.get_client_from_session_key(state=auth_query["state"]) # register a response - p = urlparse(CLIENT_CONFIG["github"]["provider_info"]["authorization_endpoint"]) - self.mock_op.register_get_response(p.path, "Redirect", 302) - - _ = client.httpc("GET", auth_query["url"]) + _url = CLIENT_CONFIG["github"]["provider_info"]["authorization_endpoint"] + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + _url, + status=302, + ) + _ = client.httpc("GET", auth_query["url"]) # the user is redirected back to the RP with a positive response auth_response = AuthorizationResponse(code="access_code", state=auth_query["state"]) @@ -910,27 +914,34 @@ def test_finalize(self): key_jar=GITHUB_KEY, ) - p = urlparse(CLIENT_CONFIG["github"]["provider_info"]["token_endpoint"]) - self.mock_op.register_post_response( - p.path, resp.to_json(), 200, {"content-type": "application/json"} - ) - - _info = OpenIDSchema( + _token_url = CLIENT_CONFIG["github"]["provider_info"]["token_endpoint"] + _user_url = CLIENT_CONFIG["github"]["provider_info"]["userinfo_endpoint"] + _user_info = OpenIDSchema( sub="EndUserSubject", given_name="Diana", family_name="Krall", occupation="Jazz pianist" ) - p = urlparse(CLIENT_CONFIG["github"]["provider_info"]["userinfo_endpoint"]) - self.mock_op.register_get_response( - p.path, _info.to_json(), 200, {"content-type": "application/json"} - ) - _github_id = iss_id("github") client.get_context().keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id ) + with responses.RequestsMock() as rsps: + rsps.add( + "POST", + _token_url, + body=resp.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + rsps.add( + "GET", + _user_url, + body=_user_info.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) - # do the rest (= get access token and user info) - # assume code flow - resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) + # do the rest (= get access token and user info) + # assume code flow + resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {"userinfo", "state", "token", "id_token", "session_state"} @@ -940,13 +951,6 @@ def test_dynamic_setup(self): rel="http://openid.net/specs/connect/1.0/issuer", href="https://server.example.com" ) webfinger_response = JRD(subject=user_id, links=[_link]) - self.mock_op.register_get_response( - "/.well-known/webfinger", - webfinger_response.to_json(), - 200, - {"content-type": "application/json"}, - ) - resp = { "authorization_endpoint": "https://server.example.com/connect/authorize", "issuer": "https://server.example.com", @@ -973,83 +977,40 @@ def test_dynamic_setup(self): ], "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], } - pcr = ProviderConfigurationResponse(**resp) - self.mock_op.register_get_response( - "/.well-known/openid-configuration", - pcr.to_json(), - 200, - {"content-type": "application/json"}, - ) - - self.mock_op.register_post_response( - "/connect/register", registration_callback, 200, {"content-type": "application/json"} - ) + _crr = {"application_type": "web", "response_types": ["code", "code id_token"], + "redirect_uris": [ + "https://example.com/rp/authz_cb" + "/7b7308fecf10c90b29303b6ae35ad1ef0f1914e49187f163335ae0b26a769e4f"], + "grant_types": ["authorization_code", "implicit"], "contacts": ["ops@example.com"], + "subject_type": "public", "id_token_signed_response_alg": "RS256", + "userinfo_signed_response_alg": "RS256", "request_object_signing_alg": "RS256", + "token_endpoint_auth_signing_alg": "RS256", "default_max_age": 86400, + "token_endpoint_auth_method": "client_secret_basic"} + _crr.update({'client_id':'abcdefghijkl', 'client_secret':rndstr(32)}) + cli_reg_resp = RegistrationResponse(**_crr) + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + "https://example.com/.well-known/webfinger", + body=webfinger_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + rsps.add( + "GET", + "https://server.example.com/.well-known/openid-configuration", + body=pcr.to_json(), + status=200, + adding_headers={"Content-Type": "application/json"}, + ) + rsps.add( + "POST", + "https://server.example.com/connect/register", + body=cli_reg_resp.to_json(), + status=200, + adding_headers={"Content-Type": "application/json"}, + ) - auth_query = self.rph.begin(user_id=user_id) + auth_query = self.rph.begin(user_id=user_id) assert auth_query - - def test_dynamic_setup_redirect_uri(self): - user_id = "acct:foobar@example.com" - _link = Link( - rel="http://openid.net/specs/connect/1.0/issuer", href="https://server.example.com" - ) - webfinger_response = JRD(subject=user_id, links=[_link]) - self.mock_op.register_get_response( - "/.well-known/webfinger", - webfinger_response.to_json(), - 200, - {"content-type": "application/json"}, - ) - - resp = { - "authorization_endpoint": "https://server.example.com/connect/authorize", - "issuer": "https://server.example.com", - "subject_types_supported": ["public"], - "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], - "userinfo_endpoint": "https://server.example.com/connect/user", - "check_id_endpoint": "https://server.example.com/connect/check_id", - "refresh_session_endpoint": "https://server.example.com/connect/refresh_session", - "end_session_endpoint": "https://server.example.com/connect/end_session", - "jwks_uri": "https://server.example.com/jwk.json", - "registration_endpoint": "https://server.example.com/connect/register", - "scopes_supported": ["openid", "profile", "email", "address", "phone"], - "response_types_supported": ["code", "code id_token", "token id_token"], - "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"], - "user_id_types_supported": ["public", "pairwise"], - "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], - "id_token_signing_alg_values_supported": [ - "HS256", - "RS256", - "A128CBC", - "A128KW", - "RSA1_5", - ], - "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], - "request_parameter_supported": True, - "request_uri_parameter_supported": True, - "require_request_uri_registration": True, - } - - pcr = ProviderConfigurationResponse(**resp) - self.mock_op.register_get_response( - "/.well-known/openid-configuration", - pcr.to_json(), - 200, - {"content-type": "application/json"}, - ) - - self.mock_op.register_post_response( - "/connect/register", registration_callback, 200, {"content-type": "application/json"} - ) - - res = self.rph.begin( - user_id=user_id, - behaviour_args={"request_param": "request", "request_object_signing_alg": "RS256"}, - ) - assert res - - _url = res["url"] - _qp = parse_qs(urlparse(_url).query) - assert "request" in _qp diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index fef02546..5e6b91ea 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -51,6 +51,7 @@ "client_id": "xxxxxxx", "client_secret": "yyyyyyyyyyyyyyyyyyyy", "redirect_uris": ["{}/authz_cb/linkedin".format(BASE_URL)], + 'client_type': 'oauth2', "preference": { "response_types": ["code"], "scope": ["r_basicprofile", "r_emailaddress"], @@ -325,8 +326,7 @@ def test_finalize_auth(self): resp = rph_1.finalize_auth(client, _session["iss"], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} aresp = ( - client.get_service("authorization") - .upstream("service_context").cstate.get(res["state"]) + client.get_service("authorization").upstream_get("context").cstate.get(res["state"]) ) assert set(aresp.keys()) == { "state", "code", 'iss', 'client_id', diff --git a/tests/test_client_55_token_exchange.py b/tests/test_client_55_token_exchange.py index 197c8b7d..976d3b6a 100644 --- a/tests/test_client_55_token_exchange.py +++ b/tests/test_client_55_token_exchange.py @@ -67,10 +67,9 @@ def create_request(self): }, } ) - entity.client_get("service_context").issuer = "https://example.com" - self.service = entity.client_get("service", "token_exchange") - - _cstate = self.service.client_get("service_context").cstate + entity.get_context().issuer = "https://example.com" + self.service = entity.get_service("token_exchange") + _cstate = self.service.upstream_get("context").cstate # Add history auth_response = AuthorizationResponse(code="access_code") _cstate.update("abcde", auth_response) diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index 1406c3be..af5d04b3 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -28,9 +28,10 @@ class Endpoint_1(Endpoint): _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], "userinfo_signing_alg_values_supported": work_environment.get_signing_algs, - "userinfo_encryption_alg_values_supported": None, - "userinfo_encryption_enc_values_supported": None, + "userinfo_encryption_alg_values_supported": work_environment.get_encryption_algs, + "userinfo_encryption_enc_values_supported": work_environment.get_encryption_encs, "client_authn_method": ["bearer_header", "bearer_body"], + "encrypt_userinfo_supported": False, } diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index dbd6fba4..9b72e6ae 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -482,8 +482,6 @@ def test_verify_per_client(self): request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) @@ -499,8 +497,6 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) @@ -508,8 +504,6 @@ def test_verify_per_client_per_endpoint(self): with pytest.raises(ClientAuthenticationError) as e: verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -517,8 +511,6 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -528,8 +520,6 @@ def test_verify_per_client_per_endpoint(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -550,8 +540,6 @@ def test_verify_client_jws_authn_method(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} http_info = {"headers": {}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), @@ -563,8 +551,6 @@ def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_3"), @@ -579,8 +565,6 @@ def test_verify_client_client_secret_basic(self): http_info = {"headers": {"authorization": authz_token}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), @@ -596,8 +580,6 @@ def test_verify_client_bearer_header(self): http_info = {"headers": {"authorization": token}} request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, @@ -629,8 +611,6 @@ def test_verify_client_jws_authn_method(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -641,8 +621,6 @@ def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} self.endpoint_context.registration_access_token["1234567890"] = client_id res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("endpoint_3"), @@ -653,8 +631,6 @@ def test_verify_client_bearer_body(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_1"), ) @@ -668,8 +644,6 @@ def test_verify_client_client_secret_basic(self): http_info = {"headers": {"authorization": authz_token}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("endpoint_1"), @@ -685,8 +659,6 @@ def test_verify_client_bearer_header(self): http_info = {"headers": {"authorization": token}} request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, @@ -699,8 +671,6 @@ def test_verify_client_authorization_none(self): # This is when it's explicitly said that no client auth method is allowed request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_2"), ) @@ -711,8 +681,6 @@ def test_verify_client_registration_public(self): # This is when no special auth method is configured request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) @@ -722,8 +690,6 @@ def test_verify_client_registration_none(self): # This is when no special auth method is configured request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("endpoint_4"), ) @@ -746,8 +712,6 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, - keyjar=server.get_attribute('keyjar'), request=request, endpoint=server.get_endpoint("endpoint_4") ) diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index bff468a0..35e9d1bf 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -211,14 +211,14 @@ def test_register_unsupported_algo(self): _msg["id_token_signed_response_alg"] = "XYZ256" _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) _resp = self.endpoint.process_request(request=_req) - assert _resp["error"] == "invalid_request" + assert "id_token_signed_response_alg" not in _resp def test_register_unsupported_set(self): _msg = MSG.copy() _msg["grant_types"] = ["authorization_code", "external"] _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) _resp = self.endpoint.process_request(request=_req) - assert _resp["error"] == "invalid_request" + assert _resp["response_args"]["grant_types"] == ["authorization_code"] def test_register_post_logout_redirect_uri_with_fragment(self): _msg = MSG.copy() diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 12697787..725510db 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -2,8 +2,9 @@ import os import shutil -import pytest from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.key_jar import init_key_jar +import pytest from idpyoidc.message.oidc import AccessTokenRequest from idpyoidc.message.oidc import AuthorizationRequest @@ -80,7 +81,7 @@ def full_path(local_file): "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, "capabilities": CAPABILITIES, - "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + # "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, "token_handler_args": { "jwks_file": "private/token_jwks.json", "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, @@ -202,12 +203,18 @@ def create_endpoint(self): except FileNotFoundError: pass + # Both have to use the same keyjar + _keyjar = init_key_jar(key_defs=KEYDEFS) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), + ENDPOINT_CONTEXT_CONFIG['issuer']) server1 = Server( - OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR, + keyjar=_keyjar ) server2 = Server( - OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR, + keyjar=_keyjar ) # The top most part (Server class instance) is not @@ -294,11 +301,11 @@ def _dump_restore(self, fro, to): def test_init(self): assert self.endpoint[1] assert set( - self.endpoint[1].server_get("endpoint_context").provider_info["scopes_supported"] + self.endpoint[1].upstream_get("context").provider_info["scopes_supported"] ) == {"openid"} - assert set( - self.endpoint[1].upstream_get("context").provider_info["claims_supported"] - ) == set(self.endpoint[2].upstream_get("context").provider_info["claims_supported"]) + assert self.endpoint[1].upstream_get("context").provider_info[ + "claims_parameter_supported"] == \ + self.endpoint[2].upstream_get("context").provider_info["claims_parameter_supported"] def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) From f2edbfe91426e002650eec3d36eb57a991e00ae7 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 12 Dec 2022 15:40:02 +0100 Subject: [PATCH 51/76] Rebased onto improved - tests working --- src/idpyoidc/client/current.py | 107 ++++++++++++++++++++ src/idpyoidc/client/oidc/authorization.py | 2 +- src/idpyoidc/client/oidc/userinfo.py | 2 - src/idpyoidc/client/service_context.py | 7 +- src/idpyoidc/server/endpoint.py | 1 - src/idpyoidc/server/endpoint_context.py | 10 -- src/idpyoidc/server/oidc/authorization.py | 1 + src/idpyoidc/server/oidc/provider_config.py | 1 - src/idpyoidc/server/oidc/registration.py | 3 - src/idpyoidc/server/oidc/token.py | 2 - 10 files changed, 113 insertions(+), 23 deletions(-) create mode 100644 src/idpyoidc/client/current.py diff --git a/src/idpyoidc/client/current.py b/src/idpyoidc/client/current.py new file mode 100644 index 00000000..d81455e5 --- /dev/null +++ b/src/idpyoidc/client/current.py @@ -0,0 +1,107 @@ +from typing import Optional +from typing import Union + +from idpyoidc.impexp import ImpExp +from idpyoidc.message import Message +from idpyoidc.util import rndstr + + +class Current(ImpExp): + """A more powerful interface to a state DB.""" + + parameter = {"_db": None, "_map": None} + + def __init__(self): + ImpExp.__init__(self) + self._db = {} + self._map = {} + + def get(self, key: str) -> dict: + """ + Get the currently used claims connected to a given key. + + :param key: Key into the state database + :return: A dictionary with the currently used claims + """ + _data = self._db.get(key) + if not _data: + raise KeyError(key) + + return _data + + def update(self, key: str, info: Union[Message, dict]) -> dict: + if isinstance(info, Message): + info = info.to_dict() + + _current = self._db.get(key) + if _current is None: + self._db[key] = info + return info + else: + _current.update(info) + self._db[key] = _current + return _current + + def set(self, key: str, info: Union[Message, dict]): + if isinstance(info, Message): + self._db[key] = info.to_dict() + else: + self._db[key] = info + + def get_set(self, + key: str, + message: Optional[type(Message)] = None, + claim: Optional[list] = None) -> dict: + """ + + @param key: The key to a seet of current claims + @param message: A message class + @param claim: A list of claims + @return: Dictionary + """ + + try: + _current = self.get(key) + except KeyError: + return {} + + if message: + _res = {k: _current[k] for k in message.c_param.keys() if k in _current} + else: + _res = {} + + if claim: + _res.update({k: _current[k] for k in claim if k in _current}) + + return _res + + def rm_claim(self, key, claim): + try: + del self._db[key][claim] + except KeyError: + pass + + def remove_state(self, key): + try: + del self._db[key] + except KeyError: + pass + else: + _mkeys = list(self._map.keys()) + for k in _mkeys: + if self._map[k] == key: + del self._map[k] + + def bind_key(self, fro, to): + self._map[fro] = to + + def get_base_key(self, key): + return self._map[key] + + def create_key(self): + return rndstr(32) + + def create_state(self, **kwargs): + _key = self.create_key() + self._db[_key] = kwargs + return _key diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 59d043af..6faacecd 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -140,7 +140,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): request_args["scope"].append("openid") # 'code' and/or 'id_token' in response_type means an ID Roken - # will eventually be returnedm, hence the need for a nonce + # will eventually be returned, hence the need for a nonce if "code" in _response_types or "id_token" in _response_types: if "nonce" not in request_args: request_args["nonce"] = rndstr(32) diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 2dd54b1c..cb99cde9 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -37,10 +37,8 @@ class UserInfo(Service): response_cls = oidc.OpenIDSchema error_msg = oidc.ResponseMessage endpoint_name = "userinfo_endpoint" - synchronous = True service_name = "userinfo" default_authn_method = "bearer_header" - http_method = "GET" _supports = { "userinfo_signing_alg_values_supported": get_signing_algs, diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 2290fb63..e6209173 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -99,7 +99,7 @@ class ServiceContext(Unit): "httpc_params": None, "iss_hash": None, "issuer": None, - 'keyjar': KeyJar, + # 'keyjar': KeyJar, "work_environment": WorkEnvironment, "provider_info": None, "requests_dir": None, @@ -167,8 +167,9 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - self.keyjar = self.work_environment.load_conf(config.conf, supports=self.supports(), - keyjar=keyjar) + _keyjar = self.work_environment.load_conf(config.conf, supports=self.supports(), + keyjar=keyjar) + self.upstream_get('set_attribute', 'keyjar', _keyjar) _response_types = self.get_preference( 'response_types_supported', diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index a3d470ed..fc8bb5de 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -90,7 +90,6 @@ class Endpoint(object): response_placement = "body" response_content_type = "" client_authn_method = "" - default_capabilities = None auth_method_attribute = "" _supports = {} diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 9bdba461..95abc2d1 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -211,16 +211,6 @@ def __init__( if _loader: self.template_handler = Jinja2TemplateHandler(_loader) - # # self.setup = {} - # _keys_conf = conf.get("key_conf") - # if _keys_conf: - # jwks_uri_path = _keys_conf["uri_path"] - # - # if self.issuer.endswith("/"): - # self.jwks_uri = "{}{}".format(self.issuer, jwks_uri_path) - # else: - # self.jwks_uri = "{}/{}".format(self.issuer, jwks_uri_path) - for item in [ "cookie_handler", "authentication", diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index ab6d074b..b6293806 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -75,6 +75,7 @@ class Authorization(authorization.Authorization): response_placement = "url" endpoint_name = "authorization_endpoint" name = "authorization" + _supports = { "claims_parameter_supported": True, "encrypt_request_object_supported": None, diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 2ac5e53e..5f6478a6 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -12,7 +12,6 @@ class ProviderConfiguration(Endpoint): request_format = "" response_format = "json" name = "provider_config" - # _supports = {"require_request_uri_registration": None} def __init__(self, upstream_get, **kwargs): Endpoint.__init__(self, upstream_get=upstream_get, **kwargs) diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index a916ada5..1c159ada 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -127,9 +127,6 @@ class Registration(Endpoint): endpoint_name = "registration_endpoint" name = "registration" - # default - # response_placement = 'body' - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index cab8f66c..5c4436f2 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -36,8 +36,6 @@ class Token(token.Token): "token_endpoint_auth_signing_alg_values_supported": work_environment.get_signing_algs, } - # auth_method_attribute = "token_endpoint_auth_methods_supported" - helper_by_grant_type = { "authorization_code": AccessTokenHelper, "refresh_token": RefreshTokenHelper, From 3d367628f81d04d790bf8811b445216b18667847 Mon Sep 17 00:00:00 2001 From: roland Date: Mon, 12 Dec 2022 17:26:21 +0100 Subject: [PATCH 52/76] Replaced work_environment/metadata with claims. Improved readability. --- doc/server/contents/conf.rst | 113 +++++++ example/flask_op/views.py | 34 +-- example/flask_rp/application.py | 11 +- private/xmetadata/__init__.py | 186 ++++++++++++ private/xmetadata/oauth2.py | 51 ++++ private/xmetadata/oidc.py | 126 ++++++++ script/client_config.py | 2 +- script/rp_handler_config.py | 2 +- src/idpyoidc/actor/client/oidc/__init__.py | 4 +- .../actor/client/oidc/registration.py | 4 +- .../{work_environment.py => claims.py} | 23 +- src/idpyoidc/client/claims/__init__.py | 61 ++++ .../{work_environment => claims}/oauth2.py | 12 +- src/idpyoidc/client/claims/oidc.py | 132 +++++++++ .../{work_environment => claims}/transform.py | 3 +- src/idpyoidc/client/client_auth.py | 14 +- src/idpyoidc/client/configure.py | 6 +- src/idpyoidc/client/current.py | 3 + src/idpyoidc/client/entity.py | 8 +- src/idpyoidc/client/oauth2/__init__.py | 2 +- src/idpyoidc/client/oauth2/access_token.py | 2 +- src/idpyoidc/client/oauth2/add_on/dpop.py | 4 +- src/idpyoidc/client/oauth2/authorization.py | 8 +- src/idpyoidc/client/oauth2/server_metadata.py | 12 +- src/idpyoidc/client/oidc/access_token.py | 4 +- src/idpyoidc/client/oidc/authorization.py | 14 +- src/idpyoidc/client/oidc/end_session.py | 20 +- .../client/oidc/provider_info_discovery.py | 4 +- .../client/oidc/refresh_access_token.py | 2 +- src/idpyoidc/client/oidc/registration.py | 3 +- src/idpyoidc/client/oidc/userinfo.py | 8 +- src/idpyoidc/client/provider/github.py | 2 +- src/idpyoidc/client/provider/linkedin.py | 2 +- src/idpyoidc/client/rp_handler.py | 32 +- src/idpyoidc/client/service.py | 10 +- src/idpyoidc/client/service_context.py | 120 ++++---- src/idpyoidc/client/util.py | 5 +- .../client/work_environment/__init__.py | 51 ---- src/idpyoidc/client/work_environment/oidc.py | 76 ----- src/idpyoidc/configure.py | 2 +- src/idpyoidc/context.py | 6 - src/idpyoidc/impexp.py | 16 +- src/idpyoidc/message/oauth2/__init__.py | 98 ++++++ src/idpyoidc/metadata.py | 279 ++++++++++++++++++ src/idpyoidc/node.py | 80 ++++- src/idpyoidc/server/__init__.py | 102 ++----- src/idpyoidc/server/claims/__init__.py | 27 ++ .../{work_environment => claims}/oauth2.py | 28 +- .../{work_environment => claims}/oidc.py | 48 ++- src/idpyoidc/server/endpoint.py | 5 +- src/idpyoidc/server/endpoint_context.py | 167 +++++++++-- src/idpyoidc/server/oauth2/add_on/dpop.py | 4 +- src/idpyoidc/server/oauth2/introspection.py | 2 +- src/idpyoidc/server/oauth2/server_metadata.py | 10 +- src/idpyoidc/server/oidc/authorization.py | 11 +- .../server/oidc/backchannel_authentication.py | 2 - src/idpyoidc/server/oidc/provider_config.py | 2 +- src/idpyoidc/server/oidc/registration.py | 19 +- src/idpyoidc/server/oidc/session.py | 3 +- src/idpyoidc/server/oidc/token.py | 4 +- src/idpyoidc/server/oidc/userinfo.py | 10 +- src/idpyoidc/server/session/grant_manager.py | 4 +- src/idpyoidc/server/session/token.py | 2 - src/idpyoidc/server/token/handler.py | 27 +- src/idpyoidc/server/token/id_token.py | 7 +- src/idpyoidc/server/token/jwt_token.py | 2 +- .../server/user_authn/authn_context.py | 2 +- .../server/work_environment/__init__.py | 14 - src/idpyoidc/storage/abfile.py | 8 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/request123456.jwt | 2 +- tests/test_08_transform.py | 144 ++++----- tests/test_09_work_condition.py | 54 ++-- tests/test_client_01_service_context.py | 17 +- tests/test_client_02_entity.py | 2 +- tests/test_client_02b_entity_metadata.py | 98 +++--- tests/test_client_04_service.py | 17 +- tests/test_client_06_client_authn.py | 13 +- tests/test_client_10_entity.py | 3 +- .../test_client_14_service_context_impexp.py | 34 ++- tests/test_client_20_oauth2.py | 4 +- tests/test_client_21_oidc_service.py | 73 ++--- tests/test_client_24_oic_utils.py | 6 +- tests/test_client_28_rp_handler_oidc.py | 40 +-- tests/test_client_30_rph_defaults.py | 41 +-- tests/test_client_40_dpop.py | 4 +- tests/test_client_41_rp_handler_persistent.py | 23 +- tests/test_client_51_identity_assurance.py | 2 +- tests/test_server_01_claims.py | 30 +- tests/test_server_03_authz_handling.py | 8 +- tests/test_server_05_token_handler.py | 4 +- tests/test_server_06_grant.py | 104 +++---- tests/test_server_08_id_token.py | 62 ++-- tests/test_server_09_authn_context.py | 4 +- tests/test_server_12_session_life.py | 30 +- tests/test_server_13_user_authn.py | 18 +- tests/test_server_15_login_hint.py | 2 +- tests/test_server_16_endpoint.py | 6 +- tests/test_server_16_endpoint_context.py | 46 +-- tests/test_server_17_client_authn.py | 56 ++-- tests/test_server_20a_server.py | 18 +- tests/test_server_20b_claims.py | 20 +- tests/test_server_20c_authz_handling.py | 6 +- tests/test_server_20d_client_authn.py | 56 ++-- tests/test_server_20e_jwt_token.py | 34 +-- ...server_22_oidc_provider_config_endpoint.py | 2 +- ...st_server_23_oidc_registration_endpoint.py | 2 +- ...server_24_oauth2_authorization_endpoint.py | 12 +- ...er_24_oauth2_authorization_endpoint_jar.py | 6 +- tests/test_server_24_oauth2_token_endpoint.py | 72 ++--- ...t_server_24_oidc_authorization_endpoint.py | 114 +++---- tests/test_server_30_oidc_end_session.py | 10 +- tests/test_server_31_oauth2_introspection.py | 26 +- .../test_server_32_oidc_read_registration.py | 2 +- tests/test_server_33_oauth2_pkce.py | 14 +- tests/test_server_34_oidc_sso.py | 28 +- tests/test_server_35_oidc_token_endpoint.py | 82 ++--- tests/test_server_36_oauth2_token_exchange.py | 40 +-- ...t_server_40_oauth2_pushed_authorization.py | 4 +- tests/test_server_50_persistence.py | 12 +- tests/test_server_60_dpop.py | 12 +- tests/test_server_61_add_on.py | 4 +- tests/test_tandem_10_oauth2_token_exchange.py | 36 +-- tests/x_test_ciba_01_backchannel_auth.py | 28 +- 125 files changed, 2462 insertions(+), 1296 deletions(-) create mode 100644 private/xmetadata/__init__.py create mode 100644 private/xmetadata/oauth2.py create mode 100644 private/xmetadata/oidc.py rename src/idpyoidc/{work_environment.py => claims.py} (93%) create mode 100644 src/idpyoidc/client/claims/__init__.py rename src/idpyoidc/client/{work_environment => claims}/oauth2.py (69%) create mode 100644 src/idpyoidc/client/claims/oidc.py rename src/idpyoidc/client/{work_environment => claims}/transform.py (98%) delete mode 100644 src/idpyoidc/client/work_environment/__init__.py delete mode 100644 src/idpyoidc/client/work_environment/oidc.py create mode 100644 src/idpyoidc/metadata.py create mode 100644 src/idpyoidc/server/claims/__init__.py rename src/idpyoidc/server/{work_environment => claims}/oauth2.py (57%) rename src/idpyoidc/server/{work_environment => claims}/oidc.py (52%) delete mode 100644 src/idpyoidc/server/work_environment/__init__.py diff --git a/doc/server/contents/conf.rst b/doc/server/contents/conf.rst index d34503a3..4138094a 100644 --- a/doc/server/contents/conf.rst +++ b/doc/server/contents/conf.rst @@ -411,6 +411,19 @@ An example:: ] } }, + "revocation": { + "path": "revoke", + "class": "idpyoidc.server.oauth2.revocation.Revocation", + "kwargs": { + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + "bearer_header" + ] + } + }, "end_session": { "path": "session", "class": "idpyoidc.server.oidc.session.Session", @@ -875,6 +888,106 @@ For example:: return request +============== +Token revocation +============== + +In order to enable the token revocation endpoint a dictionary with key `token_revocation` should be placed +under the `endpoint` key of the configuration. + +If present, the token revocation configuration should contain a `policy` dictionary +that defines the behaviour for each token type. Each token type +is mapped to a dictionary with the keys `callable` (mandatory), which must be a +python callable or a string that represents the path to a python callable, and +`kwargs` (optional), which must be a dict of key-value arguments that will be +passed to the callable. + +The key `""` represents a fallback policy that will be used if the token +type can't be found. If a token type is defined in the `policy` but is +not in the `token_types_supported` list then it is ignored. + +"token_revocation": { + "path": "revoke", + "class": "idpyoidc.server.oauth2.token_revocation.TokenRevocation", + "kwargs": { + "token_types_supported": ["access_token"], + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + "bearer_header" + ], + "policy": { + "urn:ietf:params:oauth:token-type:access_token": { + "callable": "/path/to/callable", + "kwargs": { + "audience": ["https://example.com"], + "scopes": ["openid"] + } + }, + "urn:ietf:params:oauth:token-type:refresh_token": { + "callable": "/path/to/callable", + "kwargs": { + "resource": ["https://example.com"], + "scopes": ["openid"] + } + }, + "": { + "callable": "/path/to/callable", + "kwargs": { + "scopes": ["openid"] + } + } + } + } +} + +For the per-client configuration a similar configuration scheme should be present in the client's +metadata under the `token_revocation` key. + +For example:: + + "token_revocation":{ + "token_types_supported": ["access_token"], + "policy": { + "urn:ietf:params:oauth:token-type:access_token": { + "callable": "/path/to/callable", + "kwargs": { + "audience": ["https://example.com"], + "scopes": ["openid"] + } + }, + "urn:ietf:params:oauth:token-type:refresh_token": { + "callable": "/path/to/callable", + "kwargs": { + "resource": ["https://example.com"], + "scopes": ["openid"] + } + }, + "": { + "callable": "/path/to/callable", + "kwargs": { + "scopes": ["openid"] + } + } + } + } + } + +The policy callable accepts a specific argument list and handles the revocation appropriately and returns +an :py:class:`idpyoidc.message.oauth2..TokenRevocationResponse` or raises an exception. + +For example:: + + def custom_token_revocation_policy(token, session_info, **kwargs): + if some_condition: + return TokenErrorResponse( + error="invalid_request", error_description="Some error occured" + ) + response_args = {"response_args": {}} + return oauth2.TokenRevocationResponse(**response_args) + ================================== idpyoidc\.server\.configure module ================================== diff --git a/example/flask_op/views.py b/example/flask_op/views.py index 5da08b0a..7846af50 100644 --- a/example/flask_op/views.py +++ b/example/flask_op/views.py @@ -119,7 +119,7 @@ def verify(authn_method): auth_args = authn_method.unpack_token(kwargs['token']) authz_request = AuthorizationRequest().from_urlencoded(auth_args['query']) - endpoint = current_app.server.upstream_get("endpoint", 'authorization') + endpoint = current_app.server.get_endpoint('authorization') _session_id = endpoint.create_session(authz_request, username, auth_args['authn_class_ref'], auth_args['iat'], authn_method) @@ -133,8 +133,7 @@ def verify(authn_method): @oidc_op_views.route('/verify/user', methods=['GET', 'POST']) def verify_user(): - authn_method = current_app.server.upstream_get( - "endpoint_context").authn_broker.get_method_by_id('user') + authn_method = current_app.server.get_context().authn_broker.get_method_by_id('user') try: return verify(authn_method) except FailedAuthentication as exc: @@ -143,8 +142,7 @@ def verify_user(): @oidc_op_views.route('/verify/user_pass_jinja', methods=['GET', 'POST']) def verify_user_pass_jinja(): - authn_method = current_app.server.upstream_get( - "endpoint_context").authn_broker.get_method_by_id('user') + authn_method = current_app.server.get_context().authn_broker.get_method_by_id('user') try: return verify(authn_method) except FailedAuthentication as exc: @@ -154,9 +152,9 @@ def verify_user_pass_jinja(): @oidc_op_views.route('/.well-known/') def well_known(service): if service == 'openid-configuration': - _endpoint = current_app.server.upstream_get("endpoint", 'provider_config') + _endpoint = current_app.server.get_endpoint('provider_config') elif service == 'webfinger': - _endpoint = current_app.server.upstream_get("endpoint", 'discovery') + _endpoint = current_app.server.get_endpoint('discovery') else: return make_response('Not supported', 400) @@ -166,45 +164,45 @@ def well_known(service): @oidc_op_views.route('/registration', methods=['GET', 'POST']) def registration(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'registration')) + current_app.server.get_endpoint('registration')) @oidc_op_views.route('/registration_api', methods=['GET', 'DELETE']) def registration_api(): if request.method == "DELETE": return service_endpoint( - current_app.server.upstream_get("endpoint", 'registration_delete')) + current_app.server.get_endpoint('registration_delete')) else: return service_endpoint( - current_app.server.upstream_get("endpoint", 'registration_read')) + current_app.server.get_endpoint('registration_read')) @oidc_op_views.route('/authorization') def authorization(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'authorization')) + current_app.server.get_endpoint('authorization')) @oidc_op_views.route('/token', methods=['GET', 'POST']) def token(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'token')) + current_app.server.get_endpoint('token')) @oidc_op_views.route('/introspection', methods=['POST']) def introspection_endpoint(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'introspection')) + current_app.server.get_endpoint('introspection')) @oidc_op_views.route('/userinfo', methods=['GET', 'POST']) def userinfo(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'userinfo')) + current_app.server.get_endpoint('userinfo')) @oidc_op_views.route('/session', methods=['GET']) def session_endpoint(): return service_endpoint( - current_app.server.upstream_get("endpoint", 'session')) + current_app.server.get_endpoint('session')) IGNORE = ["cookie", "user-agent"] @@ -298,7 +296,7 @@ def check_session_iframe(): req_args = dict([(k, v) for k, v in request.form.items()]) if req_args: - _context = current_app.server.upstream_get("endpoint_context") + _context = current_app.server.get_context() # will contain client_id and origin if req_args['origin'] != _context.issuer: return 'error' @@ -314,7 +312,7 @@ def check_session_iframe(): @oidc_op_views.route('/verify_logout', methods=['GET', 'POST']) def verify_logout(): - part = urlparse(current_app.server.upstream_get("endpoint_context").issuer) + part = urlparse(current_app.server.get_context().issuer) page = render_template('logout.html', op=part.hostname, do_logout='rp_logout', sjwt=request.args['sjwt']) return page @@ -322,7 +320,7 @@ def verify_logout(): @oidc_op_views.route('/rp_logout', methods=['GET', 'POST']) def rp_logout(): - _endp = current_app.server.upstream_get("endpoint", 'session') + _endp = current_app.server.get_endpoint('session') _info = _endp.unpack_signed_jwt(request.form['sjwt']) try: request.form['logout'] diff --git a/example/flask_rp/application.py b/example/flask_rp/application.py index 0ed7c3c1..cf58d426 100644 --- a/example/flask_rp/application.py +++ b/example/flask_rp/application.py @@ -23,9 +23,14 @@ def init_oidc_rp_handler(app): _path = '' _kj.httpc_params = _rp_conf.httpc_params - rph = RPHandler(_rp_conf.base_url, _rp_conf.clients, services=_rp_conf.services, - hash_seed=_rp_conf.hash_seed, keyjar=_kj, jwks_path=_path, - httpc_params=_rp_conf.httpc_params) + rph = RPHandler(base_url=_rp_conf.base_url, + client_configs=_rp_conf.clients, + services=_rp_conf.services, + keyjar=_kj, + hash_seed=_rp_conf.hash_seed, + httpc_params=_rp_conf.httpc_params, + jwks_path=_path, + ) return rph diff --git a/private/xmetadata/__init__.py b/private/xmetadata/__init__.py new file mode 100644 index 00000000..51f5770c --- /dev/null +++ b/private/xmetadata/__init__.py @@ -0,0 +1,186 @@ +import logging +from typing import Optional + +from cryptojwt.exception import IssuerNotFound +from cryptojwt.jwk.hmac import SYMKey + +from idpyoidc import metadata +from idpyoidc.message import Message +from idpyoidc.metadata import array_or_singleton +from idpyoidc.metadata import is_subset + +logger = logging.getLogger(__name__) + + +class Metadata(metadata.Metadata): + register2preferred = {} + registration_response = Message + registration_request = Message + + def get_base_url(self, configuration: dict): + _base = configuration.get('base_url') + if not _base: + _base = configuration.get('client_id') + + return _base + + def get_id(self, configuration: dict): + return self.get_preference('client_id') + + def add_extra_keys(self, keyjar, id): + _secret = self.get_preference('client_secret') + if _secret: + _new = SYMKey(key=_secret) + try: + _id_keys = keyjar.get_issuer_keys(id) + except IssuerNotFound: + keyjar.add_symmetric(issuer_id=id, key=_secret) + else: + if _new not in _id_keys: + keyjar.add_symmetric(issuer_id=id, key=_secret) + + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + keyjar.add_symmetric(issuer_id='', key=_secret) + else: + if _new not in _own_keys: + keyjar.add_symmetric(issuer_id='', key=_secret) + + def get_jwks(self, keyjar): + _jwks = None + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + pass + else: + if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): + pass + else: + _jwks = keyjar.export_jwks() + + return _jwks + + def supported_to_preferred(self, + supported: dict, + base_url: str, + info: Optional[dict] = None): + if info: # The provider info + for key, val in supported.items(): + if key in self.prefer: + _pref_val = self.prefer.get(key) # defined in configuration + _info_val = info.get(key) + if _info_val: + # Only use provider setting if less or equal to what I support + if key.endswith('supported'): # list + self.prefer[key] = [x for x in _pref_val if x in _info_val] + else: + pass + elif val is None: # No default, means the RP does not have a self.prefer + # if key not in ['jwks_uri', 'jwks']: + pass + else: + # there is a default + _info_val = info.get(key) + if _info_val: # The OP has an opinion + if key.endswith('supported'): # list + self.prefer[key] = [x for x in val if x in _info_val] + else: + pass + else: + self.prefer[key] = val + + # special case -> must have a request_uris value + if 'require_request_uri_registration' in info: + # only makes sense if I want to use request_uri + if self.prefer.get('request_parameter') == 'request_uri': + if 'request_uri' not in self.prefer: + self.prefer['request_uris'] = [f'{base_url}/requests'] + else: # just ignore + logger.info('Asked for "request_uri" which it did not plan to use') + else: + # Add defaults + for key, val in supported.items(): + if val is None: + continue + if key not in self.prefer: + self.prefer[key] = val + + def preferred_to_registered(self, + supported: dict, + response: Optional[dict] = None): + """ + The claims with values that are returned from the OP is what goes unless (!!) + the values returned are not within the supported values. + + @param registration_response: + @return: + """ + registered = {} + + if response: + for key, val in response.items(): + if key in self.register2preferred: + if is_subset(val, supported.get(self.register2preferred[key])): + registered[key] = val + else: + logger.warning( + f'OP tells me to do something I do not support: {key} = {val}') + else: + registered[key] = val # Should I just accept with the OP says ?? + + for key, spec in self.registration_response.c_param.items(): + if key in registered: + continue + _pref_key = self.register2preferred.get(key, key) + + _preferred_values = self.prefer.get(_pref_key, self.prefer.get(key)) + if not _preferred_values: + continue + + registered[key] = array_or_singleton(spec, _preferred_values) + + # transfer those claims that are not part of the registration request + _rr_keys = list(self.registration_response.c_param.keys()) + for key, val in self.prefer.items(): + _reg_key = self.register2preferred.get(key, key) + if _reg_key not in _rr_keys: + # If they are not part of the registration request I do not know if it is + # supposed to be a singleton or an array. So just add it as is. + registered[_reg_key] = val + + # all those others + _filtered_registered = {k: v for k, v in registered.items() if k not in + self.register2preferred.keys() and k not in + self.register2preferred.values()} + + # Removed supported if value chosen + for key, val in self.register2preferred.items(): + if val in registered: + if key in registered: + _filtered_registered[key] = registered[key] + elif registered[val] != []: + _filtered_registered[val] = registered[val] + elif key in registered: + _filtered_registered[key] = registered[key] + + logger.debug(f"Entity registered: {_filtered_registered}") + self.use = _filtered_registered + return _filtered_registered + + def create_registration_request(self, supported): + _request = {} + for key, spec in self.registration_request.c_param.items(): + _pref_key = self.register2preferred.get(key, key) + if _pref_key in self.prefer: + value = self.prefer[_pref_key] + elif _pref_key in supported: + value = supported[_pref_key] + else: + continue + + if not value: + continue + + _request[key] = array_or_singleton(spec, value) + return _request diff --git a/private/xmetadata/oauth2.py b/private/xmetadata/oauth2.py new file mode 100644 index 00000000..b55909a7 --- /dev/null +++ b/private/xmetadata/oauth2.py @@ -0,0 +1,51 @@ +from typing import Optional + +from idpyoidc.client import metadata +from idpyoidc.message.oauth2 import OauthClientInformationResponse +from idpyoidc.message.oauth2 import OauthClientMetadata + +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", +} + + +class Metadata(metadata.Metadata): + _supports = { + # "client_authn_methods": get_client_authn_methods, + "redirect_uris": None, + "grant_types": ["authorization_code", "implicit", "refresh_token"], + 'token_endpoint_auth_method' + "response_types": ["code"], + "client_id": None, + 'client_secret': None, + "client_name": None, + "client_uri": None, + "logo_uri": None, + "contacts": None, + "scopes_supported": [], + "tos_uri": None, + "policy_uri": None, + "jwks_uri": None, + "jwks": None, + "software_id": None, + "software_version": None + } + + callback_path = {} + + callback_uris = ["redirect_uris"] + + register2preferred = REGISTER2PREFERRED + registration_response = OauthClientInformationResponse + registration_request = OauthClientMetadata + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): + metadata.Metadata.__init__(self, prefer=prefer, callback_path=callback_path) diff --git a/private/xmetadata/oidc.py b/private/xmetadata/oidc.py new file mode 100644 index 00000000..6126e045 --- /dev/null +++ b/private/xmetadata/oidc.py @@ -0,0 +1,126 @@ +import logging +import os +from typing import Optional + +from idpyoidc import metadata +from idpyoidc.client import metadata as client_metadata +from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.message.oidc import RegistrationResponse + +logger = logging.getLogger(__name__) + +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + # "display": "display_values_supported", + # "claims": "claims_supported", + # "request": "request_parameter_supported", + # "request_uri": "request_uri_parameter_supported", + # 'claims_locales': 'claims_locales_supported', + # 'ui_locales': 'ui_locales_supported', +} + +PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) + +REQUEST2REGISTER = { + 'client_id': "client_id", + "client_secret": "client_secret", + # 'acr_values': "default_acr_values" , + # 'max_age': "default_max_age", + 'redirect_uri': "redirect_uris", + 'response_type': "response_types", + 'request_uri': "request_uris", + 'grant_type': "grant_types", + "scope": 'scopes_supported', + 'post_logout_redirect_uri': "post_logout_redirect_uris" +} + + +class Metadata(client_metadata.Metadata): + parameter = client_metadata.Metadata.parameter.copy() + parameter.update({ + "requests_dir": None + }) + + register2preferred = REGISTER2PREFERRED + registration_response = RegistrationResponse + registration_request = RegistrationRequest + + _supports = { + "acr_values_supported": None, + "application_type": "web", + "callback_uris": None, + # "client_authn_methods": get_client_authn_methods, + "client_id": None, + "client_name": None, + "client_secret": None, + "client_uri": None, + "contacts": None, + "default_max_age": 86400, + "encrypt_id_token_supported": None, + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "logo_uri": None, + "id_token_signing_alg_values_supported": metadata.get_signing_algs, + "id_token_encryption_alg_values_supported": metadata.get_encryption_algs, + "id_token_encryption_enc_values_supported": metadata.get_encryption_encs, + "initiate_login_uri": None, + "jwks": None, + "jwks_uri": None, + "policy_uri": None, + "requests_dir": None, + "require_auth_time": None, + "sector_identifier_uri": None, + "scopes_supported": ["openid"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "tos_uri": None, + } + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None + ): + client_metadata.Metadata.__init__(self, prefer=prefer, callback_path=callback_path) + + def verify_rules(self): + if self.get_preference("request_parameter_supported") and self.get_preference( + "request_uri_parameter_supported"): + raise ValueError( + "You have to chose one of 'request_parameter_supported' and " + "'request_uri_parameter_supported'. You can't have both.") + + if not self.get_preference('encrypt_userinfo_supported'): + self.set_preference('userinfo_encryption_alg_values_supported', []) + self.set_preference('userinfo_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_request_object_supported'): + self.set_preference('request_object_encryption_alg_values_supported', []) + self.set_preference('request_object_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_id_token_supported'): + self.set_preference('id_token_encryption_alg_values_supported', []) + self.set_preference('id_token_encryption_enc_values_supported', []) + + def locals(self, info): + requests_dir = info.get("requests_dir") + if requests_dir: + # make sure the path exists. If not, then create it. + if not os.path.isdir(requests_dir): + os.makedirs(requests_dir) + + self.set("requests_dir", requests_dir) diff --git a/script/client_config.py b/script/client_config.py index a5040b16..583b38d8 100644 --- a/script/client_config.py +++ b/script/client_config.py @@ -32,7 +32,7 @@ _services = rp.client_get("services") for srv, item in _services.db.items(): _data = {"class": qualified_name(item.__class__)} - for attr in ["metadata", "usage", "default_request_args", "callback_uri"]: + for attr in ["claims", "usage", "default_request_args", "callback_uri"]: _val = getattr(item, attr) if _val: _data[attr] = _val diff --git a/script/rp_handler_config.py b/script/rp_handler_config.py index 745e2e00..ebd828ab 100644 --- a/script/rp_handler_config.py +++ b/script/rp_handler_config.py @@ -32,7 +32,7 @@ def display(rp, id): _services = rp.client_get("services") for srv, item in _services.db.items(): _data = {"class": qualified_name(item.__class__)} - for attr in ["metadata", "usage", "default_request_args", "callback_uri"]: + for attr in ["claims", "usage", "default_request_args", "callback_uri"]: _val = getattr(item, attr) if _val: _data[attr] = _val diff --git a/src/idpyoidc/actor/client/oidc/__init__.py b/src/idpyoidc/actor/client/oidc/__init__.py index ce1ab9d8..904f6a64 100644 --- a/src/idpyoidc/actor/client/oidc/__init__.py +++ b/src/idpyoidc/actor/client/oidc/__init__.py @@ -45,7 +45,7 @@ def get_client_id_from_token(self, token): return _context["client_id"] def do_client_notification(self, msg, http_info): - _notification_endpoint = self.server.server_get("endpoint", "client_notification") + _notification_endpoint = self.server.upstream_get("endpoint", "client_notification") _nreq = _notification_endpoint.parse_request( msg, http_info, get_client_id_from_token=self.get_client_id_from_token ) @@ -54,6 +54,6 @@ def do_client_notification(self, msg, http_info): def construct_metadata(self): _reg_serv = self.client.client_get("service", "registration") _reg_serv.construct_request() - # _reg_endp = self.server.server_get("endpoint", "discovery") + # _reg_endp = self.server.upstream_get("endpoint", "discovery") return {} diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/src/idpyoidc/actor/client/oidc/registration.py index 2c98c411..f65500c2 100644 --- a/src/idpyoidc/actor/client/oidc/registration.py +++ b/src/idpyoidc/actor/client/oidc/registration.py @@ -162,10 +162,10 @@ def add_client_preference(self, request_args=None, **kwargs): continue try: - request_args[prop] = _context.work_environment.get_usage(prop) + request_args[prop] = _context.metadata.get_usage(prop) except KeyError: try: - request_args[prop] = _context.work_environment.get_preference[prop] + request_args[prop] = _context.metadata.get_preference[prop] except KeyError: pass return request_args, {} diff --git a/src/idpyoidc/work_environment.py b/src/idpyoidc/claims.py similarity index 93% rename from src/idpyoidc/work_environment.py rename to src/idpyoidc/claims.py index 4e5eabad..22198795 100644 --- a/src/idpyoidc/work_environment.py +++ b/src/idpyoidc/claims.py @@ -3,32 +3,29 @@ from typing import Optional from cryptojwt import KeyJar -from cryptojwt.exception import IssuerNotFound from cryptojwt.jwe import SUPPORTED -from cryptojwt.jwk.hmac import SYMKey from cryptojwt.jws.jws import SIGNER_ALGS from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import importer -from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD from idpyoidc.client.util import get_uri from idpyoidc.impexp import ImpExp from idpyoidc.util import add_path from idpyoidc.util import qualified_name -def work_environment_dump(info, exclude_attributes): +def claims_dump(info, exclude_attributes): return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} -def work_environment_load(item: dict, **kwargs): +def claims_load(item: dict, **kwargs): _class_name = list(item.keys())[0] # there is only one _cls = importer(_class_name) _cls = _cls().load(item[_class_name]) return _cls -class WorkEnvironment(ImpExp): +class Claims(ImpExp): parameter = { "prefer": None, "use": None, @@ -129,10 +126,14 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): _httpc_params = conf.get("httpc_params") if _httpc_params: _keyjar.httpc_params = _httpc_params - return _keyjar, _uri_path else: - return keyjar, _uri_path + if "keys" in conf: + _uri_path = conf['keys'].get('uri_path') + elif "key_conf" in conf and conf["key_conf"]: + _uri_path = conf['key_conf'].get('uri_path') + + return keyjar, _uri_path def get_base_url(self, configuration: dict): raise NotImplementedError() @@ -151,7 +152,9 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) - self.add_extra_keys(keyjar, _id) + _kj = self.add_extra_keys(keyjar, _id) + if keyjar is None and _kj: + keyjar = _kj # now that keys are in the Key Jar, now for how to publish it if 'jwks_uri' in configuration: # simple @@ -246,5 +249,3 @@ def get_encryption_algs(): def get_encryption_encs(): return SUPPORTED['enc'] - - diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py new file mode 100644 index 00000000..66365344 --- /dev/null +++ b/src/idpyoidc/client/claims/__init__.py @@ -0,0 +1,61 @@ +from cryptojwt import KeyJar +from cryptojwt.exception import IssuerNotFound +from cryptojwt.jwk.hmac import SYMKey + +from idpyoidc import claims +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD + + +def get_client_authn_methods(): + return list(CLIENT_AUTHN_METHOD.keys()) + + +class Claims(claims.Claims): + + def get_base_url(self, configuration: dict): + _base = configuration.get('base_url') + if not _base: + _base = configuration.get('client_id') + + return _base + + def get_id(self, configuration: dict): + return self.get_preference('client_id') + + def _add_key_if_missing(self, keyjar, id, key): + try: + old_keys = keyjar.get_issuer_keys(id) + except IssuerNotFound: + old_keys = [] + + _new_key = SYMKey(key=key) + if _new_key not in old_keys: + keyjar.add_symmetric(issuer_id=id, key=key) + + def add_extra_keys(self, keyjar, id): + _secret = self.get_preference('client_secret') + if _secret: + if keyjar is None: + keyjar = KeyJar() + self._add_key_if_missing(keyjar, id, _secret) + self._add_key_if_missing(keyjar, '', _secret) + + def get_jwks(self, keyjar): + if keyjar is None: + return None + + _jwks = None + try: + _own_keys = keyjar.get_issuer_keys('') + except IssuerNotFound: + pass + else: + # if only one key under the id == "", that key being a SYMKey I assume it's + # and I have a client_secret then don't publish a JWKS + if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey) and self.prefer[ + 'client_secret']: + pass + else: + _jwks = keyjar.export_jwks() + + return _jwks diff --git a/src/idpyoidc/client/work_environment/oauth2.py b/src/idpyoidc/client/claims/oauth2.py similarity index 69% rename from src/idpyoidc/client/work_environment/oauth2.py rename to src/idpyoidc/client/claims/oauth2.py index 8212151b..59536885 100644 --- a/src/idpyoidc/client/work_environment/oauth2.py +++ b/src/idpyoidc/client/claims/oauth2.py @@ -1,12 +1,11 @@ from typing import Optional -from idpyoidc.client import work_environment -# from idpyoidc.client.client_auth import get_client_authn_methods +from idpyoidc.client import claims +from idpyoidc.client.claims.transform import create_registration_request -class WorkEnvironment(work_environment.WorkEnvironment): +class Claims(claims.Claims): _supports = { - # "client_authn_methods": get_client_authn_methods, "redirect_uris": None, "grant_types": ["authorization_code", "implicit", "refresh_token"], "response_types": ["code"], @@ -32,4 +31,7 @@ class WorkEnvironment(work_environment.WorkEnvironment): def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): - work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) + claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) + + def create_registration_request(self): + return create_registration_request(self.prefer, self.supports()) diff --git a/src/idpyoidc/client/claims/oidc.py b/src/idpyoidc/client/claims/oidc.py new file mode 100644 index 00000000..d68fbeec --- /dev/null +++ b/src/idpyoidc/client/claims/oidc.py @@ -0,0 +1,132 @@ +import logging +import os +from typing import Optional + +from idpyoidc import claims +from idpyoidc.client import claims as client_claims +from idpyoidc.client.claims.transform import create_registration_request +from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.message.oidc import RegistrationResponse + +logger = logging.getLogger(__name__) + +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", + # "display": "display_values_supported", + # "claims": "claims_supported", + # "request": "request_parameter_supported", + # "request_uri": "request_uri_parameter_supported", + # 'claims_locales': 'claims_locales_supported', + # 'ui_locales': 'ui_locales_supported', +} + +PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) + +REQUEST2REGISTER = { + 'client_id': "client_id", + "client_secret": "client_secret", + # 'acr_values': "default_acr_values" , + # 'max_age': "default_max_age", + 'redirect_uri': "redirect_uris", + 'response_type': "response_types", + 'request_uri': "request_uris", + 'grant_type': "grant_types", + "scope": 'scopes_supported', + 'post_logout_redirect_uri': "post_logout_redirect_uris" +} + + +class Claims(client_claims.Claims): + parameter = client_claims.Claims.parameter.copy() + parameter.update({ + "requests_dir": None + }) + + register2preferred = REGISTER2PREFERRED + registration_response = RegistrationResponse + registration_request = RegistrationRequest + + _supports = { + "acr_values_supported": None, + "application_type": "web", + "callback_uris": None, + # "client_authn_methods": get_client_authn_methods, + "client_id": None, + "client_name": None, + "client_secret": None, + "client_uri": None, + "contacts": None, + "default_max_age": 86400, + "encrypt_id_token_supported": None, + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "logo_uri": None, + "id_token_signing_alg_values_supported": claims.get_signing_algs, + "id_token_encryption_alg_values_supported": claims.get_encryption_algs, + "id_token_encryption_enc_values_supported": claims.get_encryption_encs, + "initiate_login_uri": None, + "jwks": None, + "jwks_uri": None, + "policy_uri": None, + "requests_dir": None, + "require_auth_time": None, + "sector_identifier_uri": None, + "scopes_supported": ["openid"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "tos_uri": None, + } + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None + ): + client_claims.Claims.__init__(self, + prefer=prefer, + callback_path=callback_path) + + def verify_rules(self): + if self.get_preference("request_parameter_supported") and self.get_preference( + "request_uri_parameter_supported"): + raise ValueError( + "You have to chose one of 'request_parameter_supported' and " + "'request_uri_parameter_supported'. You can't have both.") + + if not self.get_preference('encrypt_userinfo_supported'): + self.set_preference('userinfo_encryption_alg_values_supported', []) + self.set_preference('userinfo_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_request_object_supported'): + self.set_preference('request_object_encryption_alg_values_supported', []) + self.set_preference('request_object_encryption_enc_values_supported', []) + + if not self.get_preference('encrypt_id_token_supported'): + self.set_preference('id_token_encryption_alg_values_supported', []) + self.set_preference('id_token_encryption_enc_values_supported', []) + + def locals(self, info): + requests_dir = info.get("requests_dir") + if requests_dir: + # make sure the path exists. If not, then create it. + if not os.path.isdir(requests_dir): + os.makedirs(requests_dir) + + self.set("requests_dir", requests_dir) + + def create_registration_request(self): + return create_registration_request(self.prefer, self.supports()) diff --git a/src/idpyoidc/client/work_environment/transform.py b/src/idpyoidc/client/claims/transform.py similarity index 98% rename from src/idpyoidc/client/work_environment/transform.py rename to src/idpyoidc/client/claims/transform.py index c2fcfe69..ec67b790 100644 --- a/src/idpyoidc/client/work_environment/transform.py +++ b/src/idpyoidc/client/claims/transform.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from idpyoidc.message import Message from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse @@ -167,7 +168,7 @@ def preferred_to_registered(prefers: dict, supported: dict, return registered -def create_registration_request(prefers, supported): +def create_registration_request(prefers: dict, supported: dict) -> dict: _request = {} for key, spec in RegistrationRequest.c_param.items(): _pref_key = REGISTER2PREFERRED.get(key, key) diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 3d38dde4..6bcff13d 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -2,6 +2,7 @@ import base64 import logging from typing import Optional +from typing import Union from cryptojwt.exception import MissingKey from cryptojwt.exception import UnsupportedAlgorithm @@ -505,8 +506,7 @@ def _construct_client_assertion(self, service, **kwargs): # construct the signed JWT with the assertions and add # it as value to the 'client_assertion' claim of the request - return assertion_jwt(_context.get_usage('client_id'), signing_key, audience, algorithm, - **_args) + return assertion_jwt(_entity.client_id, signing_key, audience, algorithm, **_args) def modify_request(self, request, service, **kwargs): """ @@ -652,13 +652,17 @@ def single_authn_setup(name, spec): return cls() -def client_auth_setup(auth_set: Optional[dict] = None): +def client_auth_setup(auth_set: Optional[Union[list, dict]] = None): if auth_set is None: auth_set = CLIENT_AUTHN_METHOD res = {} - for name, spec in auth_set.items(): - res[name] = single_authn_setup(name, spec) + if isinstance(auth_set, list): # From the known set + for name in auth_set: + res[name] = single_authn_setup(name, None) + else: + for name, spec in auth_set.items(): + res[name] = single_authn_setup(name, spec) return res diff --git a/src/idpyoidc/client/configure.py b/src/idpyoidc/client/configure.py index 966e7f6c..3a7fa911 100755 --- a/src/idpyoidc/client/configure.py +++ b/src/idpyoidc/client/configure.py @@ -63,7 +63,7 @@ def __init__( self.default = lower_or_upper(conf, "default", {}) - for param in ["services", "metadata", "add_ons", "usage"]: + for param in ["services", "claims", "add_ons", "usage"]: _val = lower_or_upper(conf, param, {}) if _val and param not in self.default: self.default[param] = _val @@ -71,7 +71,7 @@ def __init__( self.clients = lower_or_upper(conf, "clients") if self.clients: for id, client in self.clients.items(): - for param in ["services", "usage", "add_ons", 'metadata']: + for param in ["services", "usage", "add_ons", 'claims']: if param not in client: if param in self.default: client[param] = self.default[param] @@ -112,7 +112,7 @@ def __init__( for attr, val in self.conf.items(): if attr in ["issuer", "key_conf"]: setattr(self, attr, val) - _del_key.append(attr) + # _del_key.append(attr) for _key in _del_key: del self.conf[_key] diff --git a/src/idpyoidc/client/current.py b/src/idpyoidc/client/current.py index d81455e5..196ec19b 100644 --- a/src/idpyoidc/client/current.py +++ b/src/idpyoidc/client/current.py @@ -48,6 +48,9 @@ def set(self, key: str, info: Union[Message, dict]): else: self._db[key] = info + def get_claim(self, key: str, claim: str) -> Union[str, None]: + return self.get(key).get(claim) + def get_set(self, key: str, message: Optional[type(Message)] = None, diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index f80c5b3e..2b876581 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -128,6 +128,7 @@ def __init__( upstream_get=self.unit_get, client_type=client_type) self.setup_client_authn_methods(config) + self.upstream_get = upstream_get def get_services(self, *arg): @@ -156,11 +157,11 @@ def get_entity(self): return self def get_client_id(self): - _val = self.context.work_environment.get_usage('client_id') + _val = self.context.claims.get_usage('client_id') if _val: return _val else: - return self.context.work_environment.get_preference('client_id') + return self.context.claims.get_preference('client_id') def setup_client_authn_methods(self, config): if config and "client_authn_methods" in config: @@ -197,3 +198,6 @@ def import_keys(self, keyspec): _bundle = KeyBundle(source=url) _keyjar.add_kb(iss, _bundle) return _keyjar + + def get_callback_uris(self): + return self.context.claims.callback_uri \ No newline at end of file diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index d3e95f5d..4800a7b9 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -152,7 +152,7 @@ def get_response( :return: """ try: - resp = self.httpc(method, url, data=body, headers=headers) + resp = self.httpc(method, url, data=body, headers=headers, **self.httpc_params) except Exception as err: logger.error("Exception on request: {}".format(err)) raise diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index fe259e86..f2fe2a56 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -8,7 +8,7 @@ from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.time_util import time_sans_frac -from idpyoidc.work_environment import get_signing_algs +from idpyoidc.claims import get_signing_algs LOGGER = logging.getLogger(__name__) diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index a83fdaa9..3122a55a 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -144,7 +144,7 @@ def dpop_header( return headers -def add_support(services, signing_algorithms): +def add_support(services, dpop_signing_alg_values_supported): """ Add the necessary pieces to make pushed authorization happen. @@ -157,7 +157,7 @@ def add_support(services, signing_algorithms): _context = _service.upstream_get("context") _context.add_on["dpop"] = { # "key": key_by_alg(signing_algorithm), - "sign_algs": signing_algorithms + "sign_algs": dpop_signing_alg_values_supported } _service.construct_extra_headers.append(dpop_header) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index e3f1b0ac..ff308d3e 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -3,7 +3,7 @@ from typing import List from typing import Optional -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oauth2.utils import set_state_parameter @@ -32,9 +32,9 @@ class Authorization(Service): _supports = { "response_types_supported": ["code", 'token'], "response_modes_supported": ['query', 'fragment'], - "request_object_signing_alg_values_supported": work_environment.get_signing_algs, - "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, - "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, + "request_object_signing_alg_values_supported": claims.get_signing_algs, + "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + "request_object_encryption_enc_values_supported": claims.get_encryption_encs, } _callback_path = { diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index ed91531e..185da6e2 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -7,6 +7,7 @@ from idpyoidc.client.defaults import OIDCONF_PATTERN from idpyoidc.client.exception import OidcServiceError from idpyoidc.client.service import Service +from idpyoidc.message import Message from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage @@ -14,7 +15,7 @@ class ServerMetadata(Service): - """The service that talks to the OAuth2 server metadata endpoint.""" + """The service that talks to the OAuth2 server claims endpoint.""" msg_type = oauth2.Message response_cls = oauth2.ASConfigurationResponse @@ -114,7 +115,7 @@ def _update_service_context(self, resp): self._set_endpoints(resp) # If I already have a Key Jar then I'll add then provider keys to - # that. Otherwise a new Key Jar is minted + # that. Otherwise, a new Key Jar is minted try: _keyjar = self.upstream_get('attribute', 'keyjar') except KeyError: @@ -127,5 +128,12 @@ def _update_service_context(self, resp): elif "jwks" in resp: _keyjar.load_keys(_pcr_issuer, jwks=resp["jwks"]) + # Combine what I prefer/supports with what the Provider supports + if isinstance(resp, Message): + _info = resp.to_dict() + else: + _info = resp + _context.map_supported_to_preferred(_info) + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): return self._update_service_context(resp) diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index b4081fbd..4fc8fb7d 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -6,7 +6,7 @@ from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG -from idpyoidc.work_environment import get_signing_algs +from idpyoidc.claims import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name @@ -64,7 +64,7 @@ def gather_verify_arguments( except KeyError: pass - _verify_args = _context.work_environment.get_usage("verify_args") + _verify_args = _context.claims.get_usage("verify_args") if _verify_args: if _verify_args: kwargs.update(_verify_args) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 6faacecd..204a8cd9 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -3,7 +3,7 @@ from typing import Optional from typing import Union -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG @@ -32,22 +32,22 @@ class Authorization(authorization.Authorization): error_msg = oidc.ResponseMessage _supports = { - "request_object_signing_alg_values_supported": work_environment.get_signing_algs, - "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, - "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, + "request_object_signing_alg_values_supported": claims.get_signing_algs, + "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + "request_object_encryption_enc_values_supported": claims.get_encryption_encs, "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', 'code id_token', 'code idtoken token'], 'request_parameter_supported': None, 'request_uri_parameter_supported': None, "request_uris": None, "request_parameter": None, - "encrypt_request_object_supported": None, + "encrypt_request_object_supported": False, "redirect_uris": None, "response_modes_supported": ['query', 'fragment', 'form_post'] } _callback_path = { - "request_uris": "req", + "request_uris": ["req"], "redirect_uris": { # based on response_types "code": "authz_cb", "token": "authz_tok_cb", @@ -183,7 +183,7 @@ def get_request_object_signing_alg(self, **kwargs): if not alg: _context = self.upstream_get("context") try: - alg = _context.work_environment.get_usage("request_object_signing_alg") + alg = _context.claims.get_usage("request_object_signing_alg") except KeyError: # Use default alg = "RS256" return alg diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 5820f89b..315e672d 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -33,7 +33,7 @@ class EndSession(Service): _callback_path = { "frontchannel_logout_uri": "fc_logout", "backchannel_logout_uri": "bc_logout", - "post_logout_redirect_uris": "session_logout" + "post_logout_redirect_uris": ["session_logout"] } def __init__(self, upstream_get, conf=None): @@ -53,22 +53,20 @@ def get_id_token_hint(self, request_args=None, **kwargs): :return: """ - _args = self.upstream_get("context").cstate.get_set(kwargs["state"], - claim=['id_token']) - try: - request_args["id_token_hint"] = _args["id_token"] - except KeyError: - pass + _id_token = self.upstream_get("context").cstate.get_claim(kwargs["state"], claim='id_token') + if _id_token: + request_args["id_token_hint"] = _id_token return request_args, {} def add_post_logout_redirect_uri(self, request_args=None, **kwargs): if "post_logout_redirect_uri" not in request_args: _uri = self.upstream_get("context").get_usage("post_logout_redirect_uris") - if isinstance(_uri, str): - request_args["post_logout_redirect_uri"] = _uri - else: # assume list - request_args["post_logout_redirect_uri"] = _uri[0] + if _uri: + if isinstance(_uri, str): + request_args["post_logout_redirect_uri"] = _uri + else: # assume list + request_args["post_logout_redirect_uri"] = _uri[0] return request_args, {} diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index bc9c1b7f..12a87b0b 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -1,9 +1,7 @@ import logging from typing import Optional -from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.oauth2 import server_metadata -from idpyoidc.client.work_environment.transform import supported_to_preferred from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -26,7 +24,7 @@ def add_redirect_uris(request_args, service=None, **kwargs): :param kwargs: Possible extra keyword arguments :return: A possibly augmented set of request arguments. """ - _work_environment = service.upstream_get("context").work_environment + _work_environment = service.upstream_get("context").claims if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. diff --git a/src/idpyoidc/client/oidc/refresh_access_token.py b/src/idpyoidc/client/oidc/refresh_access_token.py index 8ee78d98..88d072b7 100644 --- a/src/idpyoidc/client/oidc/refresh_access_token.py +++ b/src/idpyoidc/client/oidc/refresh_access_token.py @@ -8,7 +8,7 @@ class RefreshAccessToken(refresh_access_token.RefreshAccessToken): error_msg = oidc.ResponseMessage def get_authn_method(self): - _work_environment = self.upstream_get("context").work_environment + _work_environment = self.upstream_get("context").claims try: return _work_environment.get_usage("token_endpoint_auth_method") except KeyError: diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 13e32644..4f202a62 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -4,7 +4,6 @@ from idpyoidc.client.entity import response_types_to_grant_types from idpyoidc.client.service import Service -from idpyoidc.client.work_environment.transform import create_registration_request from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -104,7 +103,7 @@ def gather_request_args(self, **kwargs): @return: """ _context = self.upstream_get("context") - req_args = create_registration_request(_context.work_environment.prefer, _context.supports()) + req_args = _context.claims.create_registration_request() if "request_args" in self.conf: req_args.update(self.conf["request_args"]) diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index cb99cde9..0a4cf22b 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -5,9 +5,9 @@ from idpyoidc import verified_claim_name from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.service import Service -from idpyoidc.work_environment import get_encryption_algs -from idpyoidc.work_environment import get_encryption_encs -from idpyoidc.work_environment import get_signing_algs +from idpyoidc.claims import get_encryption_algs +from idpyoidc.claims import get_encryption_encs +from idpyoidc.claims import get_signing_algs from idpyoidc.exception import MissingSigningKey from idpyoidc.message import Message from idpyoidc.message import oidc @@ -44,7 +44,7 @@ class UserInfo(Service): "userinfo_signing_alg_values_supported": get_signing_algs, "userinfo_encryption_alg_values_supported": get_encryption_algs, "userinfo_encryption_enc_values_supported": get_encryption_encs, - "encrypt_userinfo_supported": None + "encrypt_userinfo_supported": False } def __init__(self, upstream_get, conf=None): diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index ef1e9fee..3c32b687 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -6,7 +6,7 @@ from idpyoidc.message import SINGLE_REQUIRED_STRING from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.work_environment import get_signing_algs +from idpyoidc.claims import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 419ad189..0d5db7ab 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -7,7 +7,7 @@ from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING from idpyoidc.message import oauth2 -from idpyoidc.work_environment import get_signing_algs +from idpyoidc.claims import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index e00599b2..59589506 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -49,7 +49,7 @@ def __init__( verify_ssl=True, client_cls=None, state_db=None, - httpc=None, + httpc=None, httpc_params=None, config=None, **kwargs, @@ -205,8 +205,13 @@ def init_client(self, issuer): # If non persistent _keyjar = client.keyjar if not _keyjar: - _keyjar = client.keyjar = KeyJar() - _keyjar.load(self.keyjar.dump()) + _keyjar = KeyJar() + _keyjar.httpc_params.update(self.httpc_params) + + for iss in self.keyjar.owners(): + _keyjar.import_jwks(self.keyjar.export_jwks(issuer_id=iss, private=True), iss) + + client.keyjar = _keyjar # If persistent nothing has to be copied _context.base_url = self.base_url @@ -395,9 +400,9 @@ def client_setup( def _get_response_type(self, context, req_args: Optional[dict] = None): if req_args: return req_args.get("response_type", - context.work_environment.get_usage("response_types")[0]) + context.claims.get_usage("response_types")[0]) else: - return context.work_environment.get_usage("response_types")[0] + return context.claims.get_usage("response_types")[0] def init_authorization( self, @@ -427,18 +432,21 @@ def init_authorization( raise ValueError("Missing state/session key") _context = client.get_context() - #_entity = client.upstream_get("entity") + # _entity = client.upstream_get("entity") _nonce = rndstr(24) _response_type = self._get_response_type(_context, req_args) request_args = { "redirect_uri": pick_redirect_uri( _context, request_args=req_args, response_type=_response_type ), - "scope": _context.work_environment.get_usage("scope"), "response_type": _response_type, "nonce": _nonce, } + _scope = _context.claims.get_usage("scope") + if _scope: + request_args['scope'] = _scope + _req_args = _context.config.get("request_args") if _req_args: if "claims" in _req_args: @@ -510,7 +518,7 @@ def get_response_type(client): :param client: A Client instance :return: The response_type """ - return client.service_context.work_environment.get_usage("response_types")[0] + return client.service_context.claims.get_usage("response_types")[0] @staticmethod def get_client_authn_method(client, endpoint): @@ -877,7 +885,7 @@ def has_active_authentication(self, state): # Look for an IdToken _arg = client.get_context().cstate.get_set(state, - claim=["__verified_id_token"]) + claim=["__verified_id_token"]) if _arg: _now = utc_time_sans_frac() @@ -926,14 +934,14 @@ def logout( post_logout_redirect_uri: Optional[str] = "", ) -> dict: """ - Does a RP initiated logout from an OP. After logout the user will be - redirect by the OP to a URL of choice (post_logout_redirect_uri). + Does an RP initiated logout from an OP. After logout the user will be + redirected by the OP to a URL of choice (post_logout_redirect_uri). :param state: Key to an active session :param client: Which client to use :param post_logout_redirect_uri: If a special post_logout_redirect_uri should be used - :return: A US + :return: Request arguments """ logger.debug(20 * "*" + " logout " + 20 * "*") diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 025be666..762cd403 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -525,7 +525,8 @@ def _do_jwt(self, info): args["allowed_enc_algs"] = enc_algs["alg"] args["allowed_enc_encs"] = enc_algs["enc"] - _jwt = JWT(key_jar=_context.get_keyjar(), **args) + + _jwt = JWT(key_jar=self.upstream_get('attribute','keyjar'), **args) _jwt.iss = _context.get_client_id() return _jwt.unpack(info) @@ -672,7 +673,12 @@ def construct_uris(self, if uri in _callback_uris: pass else: - _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(uri), hex) + _path = self._callback_path.get(uri) + if isinstance(_path, str): + _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(_path), hex) + else: + _callback_uris[uri] = [self.get_uri(base_url, self._callback_path.get(_var), + hex) for _var in _path] return _callback_uris diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index e6209173..1da9765e 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -9,25 +9,24 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import RSAKey from cryptojwt.jwk.rsa import import_private_rsa_key_from_file +from cryptojwt.jwk.rsa import RSAKey from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes +from idpyoidc.claims import Claims +from idpyoidc.claims import claims_dump +from idpyoidc.claims import claims_load +from idpyoidc.client.claims.oauth2 import Claims as OAUTH2_Specs +from idpyoidc.client.claims.oidc import Claims as OIDC_Specs from idpyoidc.client.configure import Configuration -from idpyoidc.client.work_environment.oauth2 import WorkEnvironment as OAUTH2_Specs -from idpyoidc.client.work_environment.oidc import WorkEnvironment as OIDC_Specs from idpyoidc.util import rndstr -from idpyoidc.work_environment import WorkEnvironment -from idpyoidc.work_environment import work_environment_dump -from idpyoidc.work_environment import work_environment_load +from .claims.transform import preferred_to_registered +from .claims.transform import supported_to_preferred from .configure import get_configuration from .current import Current -from .work_environment.transform import preferred_to_registered -from .work_environment.transform import supported_to_preferred from ..impexp import ImpExp -from ..node import Unit logger = logging.getLogger(__name__) @@ -78,7 +77,7 @@ } -class ServiceContext(Unit): +class ServiceContext(ImpExp): """ This class keeps information that a client needs to be able to talk to a server. Some of this information comes from configuration and some @@ -99,8 +98,8 @@ class ServiceContext(Unit): "httpc_params": None, "iss_hash": None, "issuer": None, - # 'keyjar': KeyJar, - "work_environment": WorkEnvironment, + 'keyjar': KeyJar, + "claims": Claims, "provider_info": None, "requests_dir": None, "registration_response": None, @@ -110,16 +109,18 @@ class ServiceContext(Unit): } special_load_dump = { - "specs": {"load": work_environment_load, "dump": work_environment_dump}, + "specs": {"load": claims_load, "dump": claims_dump}, } + init_args = ['upstream_get'] + def __init__(self, + upstream_get: Optional[Callable] = None, base_url: Optional[str] = "", + keyjar: Optional[KeyJar] = None, config: Optional[Union[dict, Configuration]] = None, cstate: Optional[Current] = None, - upstream_get: Optional[Callable] = None, - client_type: Optional[str] = 'oidc', - keyjar: Optional[KeyJar] = None, + client_type: Optional[str] = 'oauth2', **kwargs): ImpExp.__init__(self) config = get_configuration(config) @@ -127,9 +128,9 @@ def __init__(self, self.upstream_get = upstream_get if not client_type or client_type == "oidc": - self.work_environment = OIDC_Specs() + self.claims = OIDC_Specs() elif client_type == "oauth2": - self.work_environment = OAUTH2_Specs() + self.claims = OAUTH2_Specs() else: raise ValueError(f"Unknown client type: {client_type}") @@ -167,9 +168,8 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - _keyjar = self.work_environment.load_conf(config.conf, supports=self.supports(), - keyjar=keyjar) - self.upstream_get('set_attribute', 'keyjar', _keyjar) + self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), + keyjar=keyjar) _response_types = self.get_preference( 'response_types_supported', @@ -211,6 +211,13 @@ def import_keys(self, keyspec): :param keyspec: """ + _keyjar = self.upstream_get('attribute', 'keyjar') + if _keyjar is None: + _keyjar = KeyJar() + new = True + else: + new = False + for where, spec in keyspec.items(): if where == "file": for typ, files in spec.items(): @@ -219,19 +226,23 @@ def import_keys(self, keyspec): _key = RSAKey(priv_key=import_private_rsa_key_from_file(fil), use="sig") _bundle = KeyBundle() _bundle.append(_key) - self.keyjar.add_kb("", _bundle) + _keyjar.add_kb("", _bundle) elif where == "url": for iss, url in spec.items(): _bundle = KeyBundle(source=url) - self.keyjar.add_kb(iss, _bundle) + _keyjar.add_kb(iss, _bundle) + + if new: + _unit = self.upstream_get('unit') + _unit.setattribute('keyjar', _keyjar) def _get_crypt(self, typ, attr): _item_typ = CLI_REG_MAP.get(typ) _alg = '' if _item_typ: - _alg = self.work_environment.get_usage(_item_typ[attr]) + _alg = self.claims.get_usage(_item_typ[attr]) if not _alg: - _alg = self.work_environment.get_preference(_item_typ[attr]) + _alg = self.claims.get_preference(_item_typ[attr]) if not _alg: _item_typ = PROVIDER_INFO_MAP.get(typ) @@ -268,41 +279,37 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.work_environment.get_usage("client_id") + return self.claims.get_usage("client_id") def collect_usage(self): - return self.work_environment.use + return self.claims.use def supports(self): res = {} if self.upstream_get: services = self.upstream_get('services') - for service in services.values(): - res.update(service.supports()) - res.update(self.work_environment.supports()) + if not services: + pass + else: + for service in services.values(): + res.update(service.supports()) + res.update(self.claims.supports()) return res def prefers(self): - return self.work_environment.prefer + return self.claims.prefer def get_preference(self, claim, default=None): - return self.work_environment.get_preference(claim, default=default) + return self.claims.get_preference(claim, default=default) def set_preference(self, key, value): - self.work_environment.set_preference(key, value) + self.claims.set_preference(key, value) def get_usage(self, claim, default: Optional[str] = None): - return self.work_environment.get_usage(claim, default) + return self.claims.get_usage(claim, default) def set_usage(self, claim, value): - return self.work_environment.set_usage(claim, value) - - def get_keyjar(self): - val = getattr(self, 'keyjar', None) - if not val: - return self.upstream_get('attribute', 'keyjar') - else: - return val + return self.claims.set_usage(claim, value) def _callback_per_service(self): _cb = {} @@ -325,10 +332,11 @@ def construct_uris(self, response_types: Optional[list] = None): _callback_uris = self.get_preference('callback_uris', {}) if self.upstream_get: services = self.upstream_get('services') - for service in services.values(): - _callback_uris.update(service.construct_uris(base_url=_base_url, hex=_hex, - context=self, - response_types=response_types)) + if services: + for service in services.values(): + _callback_uris.update(service.construct_uris(base_url=_base_url, hex=_hex, + context=self, + response_types=response_types)) self.set_preference('callback_uris', _callback_uris) if 'redirect_uris' in _callback_uris: @@ -338,7 +346,7 @@ def construct_uris(self, response_types: Optional[list] = None): self.set_preference('redirect_uris', list(_redirect_uris)) def prefer_or_support(self, claim): - if claim in self.work_environment.prefer: + if claim in self.claims.prefer: return 'prefer' else: for service in self.upstream_get('services').values(): @@ -346,20 +354,20 @@ def prefer_or_support(self, claim): if _res: return _res - if claim in self.work_environment.supported(claim): + if claim in self.claims.supported(claim): return 'support' return None def map_supported_to_preferred(self, info: Optional[dict] = None): - self.work_environment.prefer = supported_to_preferred(self.supports(), - self.work_environment.prefer, - base_url=self.base_url, - info=info) - return self.work_environment.prefer + self.claims.prefer = supported_to_preferred(self.supports(), + self.claims.prefer, + base_url=self.base_url, + info=info) + return self.claims.prefer def map_preferred_to_registered(self, registration_response: Optional[dict] = None): - self.work_environment.use = preferred_to_registered( - self.work_environment.prefer, + self.claims.use = preferred_to_registered( + self.claims.prefer, supported=self.supports(), registration_response=registration_response) - return self.work_environment.use + return self.claims.use diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index a5d78c73..e2418cd2 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -1,8 +1,8 @@ """Utilities""" -from http.cookiejar import Cookie -from http.cookiejar import http2time import logging import secrets +from http.cookiejar import Cookie +from http.cookiejar import http2time from urllib.parse import parse_qs from urllib.parse import urlsplit from urllib.parse import urlunsplit @@ -322,5 +322,6 @@ def implicit_response_types(a): res.append(typ) return res + def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" diff --git a/src/idpyoidc/client/work_environment/__init__.py b/src/idpyoidc/client/work_environment/__init__.py deleted file mode 100644 index 082effd5..00000000 --- a/src/idpyoidc/client/work_environment/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -from cryptojwt.exception import IssuerNotFound -from cryptojwt.jwk.hmac import SYMKey - -from idpyoidc import work_environment - - -class WorkEnvironment(work_environment.WorkEnvironment): - - def get_base_url(self, configuration: dict): - _base = configuration.get('base_url') - if not _base: - _base = configuration.get('client_id') - - return _base - - def get_id(self, configuration: dict): - return self.get_preference('client_id') - - def add_extra_keys(self, keyjar, id): - _secret = self.get_preference('client_secret') - if _secret: - _new = SYMKey(key=_secret) - try: - _id_keys = keyjar.get_issuer_keys(id) - except IssuerNotFound: - keyjar.add_symmetric(issuer_id=id, key=_secret) - else: - if _new not in _id_keys: - keyjar.add_symmetric(issuer_id=id, key=_secret) - - try: - _own_keys = keyjar.get_issuer_keys('') - except IssuerNotFound: - keyjar.add_symmetric(issuer_id='', key=_secret) - else: - if _new not in _own_keys: - keyjar.add_symmetric(issuer_id='', key=_secret) - - def get_jwks(self, keyjar): - _jwks = None - try: - _own_keys = keyjar.get_issuer_keys('') - except IssuerNotFound: - pass - else: - if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey): - pass - else: - _jwks = keyjar.export_jwks() - - return _jwks diff --git a/src/idpyoidc/client/work_environment/oidc.py b/src/idpyoidc/client/work_environment/oidc.py deleted file mode 100644 index dae2b4b8..00000000 --- a/src/idpyoidc/client/work_environment/oidc.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -from typing import Optional - -from idpyoidc import work_environment -from idpyoidc.client import work_environment as client_work_environment -# from idpyoidc.client.client_auth import get_client_authn_methods - - -class WorkEnvironment(client_work_environment.WorkEnvironment): - parameter = work_environment.WorkEnvironment.parameter.copy() - parameter.update({ - "requests_dir": None - }) - - _supports = { - "acr_values_supported": None, - "application_type": "web", - "callback_uris": None, - # "client_authn_methods": get_client_authn_methods, - "client_id": None, - "client_name": None, - "client_secret": None, - "client_uri": None, - "contacts": None, - "default_max_age": 86400, - "encrypt_id_token_supported": None, - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "logo_uri": None, - "id_token_signing_alg_values_supported": work_environment.get_signing_algs, - "id_token_encryption_alg_values_supported": work_environment.get_encryption_algs, - "id_token_encryption_enc_values_supported": work_environment.get_encryption_encs, - "initiate_login_uri": None, - "jwks": None, - "jwks_uri": None, - "policy_uri": None, - "requests_dir": None, - "require_auth_time": None, - "sector_identifier_uri": None, - "scopes_supported": ["openid"], - "subject_types_supported": ["public", "pairwise", "ephemeral"], - "tos_uri": None, - } - - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None - ): - work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) - - def verify_rules(self): - if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported"): - raise ValueError( - "You have to chose one of 'request_parameter_supported' and " - "'request_uri_parameter_supported'. You can't have both.") - - if not self.get_preference('encrypt_userinfo_supported'): - self.set_preference('userinfo_encryption_alg_values_supported', []) - self.set_preference('userinfo_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_request_object_supported'): - self.set_preference('request_object_encryption_alg_values_supported', []) - self.set_preference('request_object_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_id_token_supported'): - self.set_preference('id_token_encryption_alg_values_supported', []) - self.set_preference('id_token_encryption_enc_values_supported', []) - - def locals(self, info): - requests_dir = info.get("requests_dir") - if requests_dir: - # make sure the path exists. If not, then create it. - if not os.path.isdir(requests_dir): - os.makedirs(requests_dir) - - self.set("requests_dir", requests_dir) diff --git a/src/idpyoidc/configure.py b/src/idpyoidc/configure.py index 45d7afb4..9b74135d 100644 --- a/src/idpyoidc/configure.py +++ b/src/idpyoidc/configure.py @@ -298,7 +298,7 @@ def create_from_config_file( domain: Optional[str] = "", port: Optional[int] = 0, dir_attributes: Optional[List[str]] = None, -): +) -> Base: return cls( load_config_file(filename), entity_conf=entity_conf, diff --git a/src/idpyoidc/context.py b/src/idpyoidc/context.py index 9b17f2b5..55ec0e6a 100644 --- a/src/idpyoidc/context.py +++ b/src/idpyoidc/context.py @@ -1,12 +1,6 @@ import copy -from typing import Optional -from typing import Union from urllib.parse import quote_plus -from cryptojwt import KeyJar -from cryptojwt.key_jar import init_key_jar -from idpyoidc.configure import Configuration - from idpyoidc.impexp import ImpExp diff --git a/src/idpyoidc/impexp.py b/src/idpyoidc/impexp.py index 297282fd..efa2ac62 100644 --- a/src/idpyoidc/impexp.py +++ b/src/idpyoidc/impexp.py @@ -78,11 +78,11 @@ def local_load_adjustments(self, **kwargs): pass def load_attr( - self, - cls: Any, - item: dict, - init_args: Optional[dict] = None, - load_args: Optional[dict] = None, + self, + cls: Any, + item: dict, + init_args: Optional[dict] = None, + load_args: Optional[dict] = None, ) -> Any: if load_args: _kwargs = {"load_args": load_args} @@ -141,7 +141,11 @@ def load(self, item: dict, init_args: Optional[dict] = None, load_args: Optional _load_args = {} if init_args: - _kwargs["init_args"] = init_args + for attr, val in init_args.items(): + if attr in self.init_args: + setattr(self, attr, val) + + _kwargs["init_args"] = init_args for attr, cls in self.parameter.items(): if attr not in item or attr in self.special_load_dump: diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index b6666f0e..2660aec0 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -1,12 +1,16 @@ import inspect +import json import logging import string import sys from idpyoidc import verified_claim_name +from idpyoidc.exception import FormatError from idpyoidc.exception import MissingAttribute +from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import VerificationError from idpyoidc.message import Message +from idpyoidc.message import msg_ser from idpyoidc.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from idpyoidc.message import OPTIONAL_LIST_OF_STRINGS from idpyoidc.message import REQUIRED_LIST_OF_SP_SEP_STRINGS @@ -346,6 +350,80 @@ class ASConfigurationResponse(Message): c_default = {"version": "3.0"} +def deserialize_from_one_of(val, msgtype, sformat): + if sformat in ["dict", "json"]: + flist = ["json", "urlencoded"] + if not isinstance(val, str): + val = json.dumps(val) + else: + flist = ["urlencoded", "json"] + + for _format in flist: + try: + return msgtype().deserialize(val, _format) + except FormatError: + pass + raise FormatError("Unexpected format") + + +class OauthClientMetadata(Message): + """Metadata for an OAuth2 Client.""" + c_param = { + "redirect_uris": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_method": SINGLE_OPTIONAL_STRING, + "grant_type": OPTIONAL_LIST_OF_STRINGS, + "response_types": OPTIONAL_LIST_OF_STRINGS, + "client_name": SINGLE_OPTIONAL_STRING, + "client_uri": SINGLE_OPTIONAL_STRING, + "logo_uri": SINGLE_OPTIONAL_STRING, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "contacts": OPTIONAL_LIST_OF_STRINGS, + "tos_uri": SINGLE_OPTIONAL_STRING, + "policy_uri": SINGLE_OPTIONAL_STRING, + "jwks_uri": SINGLE_OPTIONAL_STRING, + "jwks": SINGLE_OPTIONAL_JSON, + "software_id": SINGLE_OPTIONAL_STRING, + "software_version": SINGLE_OPTIONAL_STRING + } + + +def oauth_client_metadata_deser(val, sformat="json"): + """Deserializes a JSON object (most likely) into a OauthClientMetadata.""" + return deserialize_from_one_of(val, OauthClientMetadata, sformat) + + +OPTIONAL_OAUTH_CLIENT_METADATA = (Message, False, msg_ser, + oauth_client_metadata_deser, False) + + +class OauthClientInformationResponse(OauthClientMetadata): + """The information returned by a OAuth2 Server about an OAuth2 client.""" + c_param = OauthClientMetadata.c_param.copy() + c_param.update({ + "client_id": SINGLE_REQUIRED_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + "client_id_issued_at": SINGLE_OPTIONAL_INT, + "client_secret_expires_at": SINGLE_OPTIONAL_INT + }) + + def verify(self, **kwargs): + super(OauthClientInformationResponse, self).verify(**kwargs) + + if "client_secret" in self: + if "client_secret_expires_at" not in self: + raise MissingRequiredAttribute( + "client_secret_expires_at is a MUST if client_secret is present") + + +def oauth_client_registration_response_deser(val, sformat="json"): + """Deserializes a JSON object (most likely) into a OauthClientInformationResponse.""" + return deserialize_from_one_of(val, OauthClientInformationResponse, sformat) + + +OPTIONAL_OAUTH_CLIENT_REGISTRATION_RESPONSE = ( + Message, False, msg_ser, oauth_client_registration_response_deser, False) + + # RFC 7662 class TokenIntrospectionRequest(Message): c_param = { @@ -484,6 +562,26 @@ class JWTAccessToken(Message): +class JSONWebToken(Message): + # implements RFC 9068 + c_param = { + 'iss': SINGLE_REQUIRED_STRING, + 'exp': SINGLE_REQUIRED_STRING, + 'aud': SINGLE_REQUIRED_STRING, + 'sub': SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, + 'iat': SINGLE_REQUIRED_STRING, + 'jti': SINGLE_REQUIRED_STRING, + 'auth_time': SINGLE_OPTIONAL_INT, + 'acr': SINGLE_OPTIONAL_STRING, + 'amr': OPTIONAL_LIST_OF_STRINGS, + 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, + 'groups': OPTIONAL_LIST_OF_STRINGS, + 'roles': OPTIONAL_LIST_OF_STRINGS, + 'entitlements': OPTIONAL_LIST_OF_STRINGS + } + + def factory(msgtype, **kwargs): """ Factory method that can be used to easily instansiate a class instance diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py new file mode 100644 index 00000000..7104d82a --- /dev/null +++ b/src/idpyoidc/metadata.py @@ -0,0 +1,279 @@ +import logging +from functools import cmp_to_key +from typing import Callable +from typing import Optional + +from cryptojwt import KeyJar +from cryptojwt.jwe import SUPPORTED +from cryptojwt.jws.jws import SIGNER_ALGS +from cryptojwt.key_jar import init_key_jar +from cryptojwt.utils import importer + +from idpyoidc.client.util import get_uri +from idpyoidc.impexp import ImpExp +from idpyoidc.util import add_path +from idpyoidc.util import qualified_name + +logger = logging.getLogger(__name__) + + +def metadata_dump(info, exclude_attributes): + return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} + + +def metadata_load(item: dict, **kwargs): + _class_name = list(item.keys())[0] # there is only one + _cls = importer(_class_name) + _cls = _cls().load(item[_class_name]) + return _cls + + +class Metadata(ImpExp): + parameter = { + "prefer": None, + "use": None, + "callback_path": None, + "_local": None + } + + _supports = {} + + def __init__(self, + prefer: Optional[dict] = None, + callback_path: Optional[dict] = None): + + ImpExp.__init__(self) + if isinstance(prefer, dict): + self.prefer = {k: v for k, v in prefer.items() if k in self._supports} + else: + self.prefer = {} + + self.callback_path = callback_path or {} + self.use = {} + self._local = {} + + def get_use(self): + return self.use + + def set_usage(self, key, value): + self.use[key] = value + + def get_usage(self, key, default=None): + return self.use.get(key, default) + + def get_preference(self, key, default=None): + return self.prefer.get(key, default) + + def set_preference(self, key, value): + self.prefer[key] = value + + def remove_preference(self, key): + if key in self.prefer: + del self.prefer[key] + + def _callback_uris(self, base_url, hex): + _uri = [] + for type in self.get_usage("response_types", self._supports['response_types']): + if "code" in type: + _uri.append('code') + elif type in ["id_token", "id_token token"]: + _uri.append('implicit') + + if "form_post" in self.supports: + _uri.append("form_post") + + callback_uri = {} + for key in _uri: + callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) + return callback_uri + + def construct_redirect_uris(self, + base_url: str, + hex: str, + callbacks: Optional[dict] = None): + if not callbacks: + callbacks = self._callback_uris(base_url, hex) + + if callbacks: + self.set_preference('callbacks', callbacks) + self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) + + self.callback = callbacks + + def verify_rules(self): + return True + + def locals(self, info): + pass + + def _keyjar(self, keyjar=None, conf=None, entity_id=""): + _uri_path = '' + if keyjar is None: + if "keys" in conf: + keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['keys'].get('uri_path') + elif "key_conf" in conf and conf["key_conf"]: + keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} + _keyjar = init_key_jar(**keys_args) + _uri_path = conf['key_conf'].get('uri_path') + else: + _keyjar = KeyJar() + if "jwks" in conf: + _keyjar.import_jwks(conf["jwks"], "") + + if "" in _keyjar and entity_id: + # make sure I have the keys under my own name too (if I know it) + _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + + _httpc_params = conf.get("httpc_params") + if _httpc_params: + _keyjar.httpc_params = _httpc_params + + return _keyjar, _uri_path + else: + if "keys" in conf: + _uri_path = conf['keys'].get('uri_path') + elif "key_conf" in conf and conf["key_conf"]: + _uri_path = conf['key_conf'].get('uri_path') + return keyjar, _uri_path + + def get_base_url(self, configuration: dict): + raise NotImplementedError() + + def get_id(self, configuration: dict): + raise NotImplementedError() + + def add_extra_keys(self, keyjar, id): + return None + + def get_jwks(self, keyjar): + return None + + def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None, + base_url: Optional[str] = ''): + _jwks = _jwks_uri = None + _id = self.get_id(configuration) + keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) + + self.add_extra_keys(keyjar, _id) + + # now that keys are in the Key Jar, now for how to publish it + if 'jwks_uri' in configuration: # simple + _jwks_uri = configuration.get('jwks_uri') + elif uri_path: + if not base_url: + base_url = self.get_base_url(configuration) + _jwks_uri = add_path(base_url, uri_path) + else: # jwks or nothing + _jwks = self.get_jwks(keyjar) + + return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} + + def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None, + base_url: Optional[str] = ''): + for attr, val in configuration.items(): + if attr == "preference": + for k, v in val.items(): + if k in supports: + self.set_preference(k, v) + elif attr in supports: + self.set_preference(attr, val) + + self.locals(configuration) + + for key, val in self.handle_keys(configuration, keyjar=keyjar, base_url=base_url).items(): + if key == 'keyjar': + keyjar = val + elif val: + self.set_preference(key, val) + + self.verify_rules() + return keyjar + + def get(self, key, default=None): + if key in self._local: + return self._local[key] + else: + return default + + def set(self, key, val): + self._local[key] = val + + def construct_uris(self, *args): + pass + + def supports(self): + res = {} + for key, val in self._supports.items(): + if isinstance(val, Callable): + res[key] = val() + else: + res[key] = val + return res + + def supported(self, claim): + return claim in self._supports + + def prefers(self): + return self.prefer + + +SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] + + +def cmp(a, b): + return (a > b) - (a < b) + + +def alg_cmp(a, b): + if a == 'none': + return 1 + elif b == 'none': + return -1 + + _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) + _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) + if _pos1 == _pos2: + return (a > b) - (a < b) + elif _pos1 > _pos2: + return 1 + else: + return -1 + + +def get_signing_algs(): + # Assumes Cryptojwt + _algs = [l for l in list(SIGNER_ALGS.keys()) if l != 'none'] + return sorted(_algs, key=cmp_to_key(alg_cmp)) + + +def get_encryption_algs(): + return SUPPORTED['alg'] + + +def get_encryption_encs(): + return SUPPORTED['enc'] + + +def array_or_singleton(claim_spec, values): + if isinstance(claim_spec[0], list): + if isinstance(values, list): + return values + else: + return [values] + else: + if isinstance(values, list): + return values[0] + else: # singleton + return values + + +def is_subset(a, b): + if isinstance(a, list): + if isinstance(b, list): + return set(b).issubset(set(a)) + elif isinstance(b, list): + return a in b + else: + return a == b diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index df1051d6..498e83e1 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -43,9 +43,43 @@ def create_keyjar( return keyjar +class Node: + + def __init__(self, upstream_get: Callable = None): + self.upstream_get = upstream_get + + def unit_get(self, what, *arg): + _func = getattr(self, f"get_{what}", None) + if _func: + return _func(*arg) + return None + + def get_attribute(self, attr, *args): + try: + val = getattr(self, attr) + except AttributeError: + if self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return None + else: + if val is None and self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return val + + def set_attribute(self, attr, val): + setattr(self, attr, val) + + def get_unit(self, *args): + return self + + class Unit(ImpExp): name = '' + init_args = ['upstream_get'] + def __init__(self, upstream_get: Callable = None, keyjar: Optional[KeyJar] = None, @@ -63,7 +97,16 @@ def __init__(self, if config is None: config = {} - if keyjar or key_conf or config.get('key_conf') or config.get('jwks') or config.get('keys'): + keyjar = keyjar or config.get('keyjar') + key_conf = key_conf or config.get('key_conf', config.get('keys')) + + if not keyjar and not key_conf: + _jwks = config.get('jwks') + if _jwks: + keyjar = KeyJar() + keyjar.import_jwks_as_json(_jwks, client_id) + + if keyjar or key_conf: # Should be either one id = issuer_id or client_id self.keyjar = create_keyjar(keyjar, conf=config, key_conf=key_conf, id=id) @@ -148,6 +191,13 @@ def __init__(self, self.context = context or None + def get_context_attribute(self, attr, *args): + _val = getattr(self.context, attr) + if not _val and self.upstream_get: + return self.upstream_get('context_attribute', attr) + else: + return _val + # Neither client nor Server class Collection(Unit): @@ -161,7 +211,7 @@ def __init__(self, entity_id: Optional[str] = "", key_conf: Optional[dict] = None, functions: Optional[dict] = None, - metadata: Optional[dict] = None + claims: Optional[dict] = None ): if config is None: config = {} @@ -175,10 +225,34 @@ def __init__(self, 'upstream_get': self.unit_get } - self.metadata = metadata or {} + self.claims = claims or {} + self.upstream_get = upstream_get + # self.context = {} if functions: for key, val in functions.items(): _kwargs = val["kwargs"].copy() _kwargs.update(_args) setattr(self, key, instantiate(val["class"], **_kwargs)) + + def get_context_attribute(self, attr, *args): + _cntx = getattr(self, 'context', None) + if _cntx: + _val = getattr(_cntx, attr, None) + if _val: + return _val + + if self.upstream_get: + return self.upstream_get('context_attribute', attr) + else: + return None + + def get_attribute(self, attr, *args): + val = getattr(self, attr, None) + if val: + return val + + if self.upstream_get: + return self.upstream_get('attribute', attr) + else: + return None diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 6218ec5a..59b93e14 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -7,8 +7,6 @@ from cryptojwt import KeyJar -from idpyoidc.impexp import ImpExp -from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.node import Unit from idpyoidc.server import authz from idpyoidc.server.client_authn import client_auth_setup @@ -16,9 +14,6 @@ from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.endpoint_context import EndpointContext -from idpyoidc.server.endpoint_context import get_provider_capabilities -from idpyoidc.server.endpoint_context import init_service -from idpyoidc.server.endpoint_context import init_user_info from idpyoidc.server.session.manager import create_session_manager from idpyoidc.server.user_authn.authn_context import populate_authn_broker from idpyoidc.server.util import allow_refresh_token @@ -58,54 +53,30 @@ def __init__( issuer_id=self.issuer) self.upstream_get = upstream_get - self.conf = conf + if isinstance(conf, OPConfiguration) or isinstance(conf, ASConfiguration): + self.conf = conf + else: + self.conf = OPConfiguration(conf) - self.endpoint = do_endpoints(conf, self.unit_get) + self.endpoint = do_endpoints(self.conf, self.unit_get) - self.endpoint_context = EndpointContext( - conf=conf, + self.context = EndpointContext( + conf=self.conf, upstream_get=self.unit_get, # points to me cwd=cwd, cookie_handler=cookie_handler, - keyjar=keyjar + keyjar=self.keyjar ) - self.endpoint_context.authz = self.setup_authz() - - self.setup_authentication(self.endpoint_context) - - # _cap = get_provider_capabilities(conf, self.endpoint) - # self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) - self.endpoint_context.do_add_on(endpoints=self.endpoint) - - self.endpoint_context.session_manager = create_session_manager( - self.unit_get, - self.endpoint_context.th_args, - sub_func=self.endpoint_context._sub_func, - conf=self.conf, - ) - self.endpoint_context.do_userinfo() - # Must be done after userinfo - self.setup_login_hint_lookup() - self.endpoint_context.set_remember_token() + # Need to have context in place before doing this + self.context.do_add_on(endpoints=self.endpoint) - self.setup_client_authn_methods() for endpoint_name, _ in self.endpoint.items(): self.endpoint[endpoint_name].upstream_get = self.unit_get _token_endp = self.endpoint.get("token") if _token_endp: - _token_endp.allow_refresh = allow_refresh_token(self.endpoint_context) - - self.endpoint_context.claims_interface = init_service( - conf["claims_interface"], self.unit_get - ) - - _id_token_handler = self.endpoint_context.session_manager.token_handler.handler.get( - "id_token" - ) - if _id_token_handler: - self.endpoint_context.provider_info.update(_id_token_handler.provider_info) + _token_endp.allow_refresh = allow_refresh_token(self.context) def get_endpoints(self, *arg): return self.endpoint @@ -117,10 +88,10 @@ def get_endpoint(self, endpoint_name, *arg): return None def get_context(self, *arg): - return self.endpoint_context + return self.context def get_endpoint_context(self, *arg): - return self.endpoint_context + return self.context def get_server(self, *args): return self @@ -128,46 +99,7 @@ def get_server(self, *args): def get_entity(self, *args): return self - def setup_authz(self): - authz_spec = self.conf.get("authz") - if authz_spec: - return init_service(authz_spec, self.unit_get) - else: - return authz.Implicit(self.unit_get) - - def setup_authentication(self, target): - _conf = self.conf.get("authentication") - if _conf: - target.authn_broker = populate_authn_broker( - _conf, self.unit_get, target.template_handler - ) - else: - target.authn_broker = {} - - target.endpoint_to_authn_method = {} - for method in target.authn_broker: - try: - target.endpoint_to_authn_method[method.action] = method - except AttributeError: - pass - - def setup_login_hint_lookup(self): - _conf = self.conf.get("login_hint_lookup") - if _conf: - _userinfo = None - _kwargs = _conf.get("kwargs") - if _kwargs: - _userinfo_conf = _kwargs.get("userinfo") - if _userinfo_conf: - _userinfo = init_user_info(_userinfo_conf, self.endpoint_context.cwd) - - if _userinfo is None: - _userinfo = self.endpoint_context.userinfo - - self.endpoint_context.login_hint_lookup = init_service(_conf) - self.endpoint_context.login_hint_lookup.userinfo = _userinfo - - def setup_client_authn_methods(self): - self.endpoint_context.client_authn_methods = client_auth_setup( - self.unit_get, self.conf.get("client_authn_methods") - ) + def get_context_attribute(self, attr, *args): + _val = getattr(self.context, attr) + if not _val and self.upstream_get: + return self.upstream_get('context_attribute', attr) diff --git a/src/idpyoidc/server/claims/__init__.py b/src/idpyoidc/server/claims/__init__.py new file mode 100644 index 00000000..4c37b47f --- /dev/null +++ b/src/idpyoidc/server/claims/__init__.py @@ -0,0 +1,27 @@ +from typing import Optional + +from idpyoidc import claims + + +class Claims(claims.Claims): + + def get_base_url(self, configuration: dict): + _base = configuration.get('base_url') + if not _base: + _base = configuration.get('issuer') + + return _base + + def get_id(self, configuration: dict): + return configuration.get('issuer') + + def supported_to_preferred(self, + supported: dict, + base_url: Optional[str] = '', + info: Optional[dict] = None): + # Add defaults + for key, val in supported.items(): + if val is None: + continue + if key not in self.prefer: + self.prefer[key] = val diff --git a/src/idpyoidc/server/work_environment/oauth2.py b/src/idpyoidc/server/claims/oauth2.py similarity index 57% rename from src/idpyoidc/server/work_environment/oauth2.py rename to src/idpyoidc/server/claims/oauth2.py index 7473329f..8b259487 100644 --- a/src/idpyoidc/server/work_environment/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -1,17 +1,22 @@ from typing import Optional from idpyoidc.message.oauth2 import ASConfigurationResponse -from idpyoidc.server import work_environment -# from idpyoidc.server.client_authn import get_client_authn_methods +from idpyoidc.server import claims +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", +} + + +class Claims(claims.Claims): + register2preferred = REGISTER2PREFERRED -class WorkEnvironment(work_environment.WorkEnvironment): - # 'issuer', 'authorization_endpoint', 'token_endpoint', 'jwks_uri', 'registration_endpoint', - # 'scopes_supported', 'response_types_supported', 'response_modes_supported', - # 'grant_types_supported', 'token_endpoint_auth_methods_supported', - # 'token_endpoint_auth_signing_alg_values_supported', 'service_documentation', - # 'ui_locales_supported', 'op_policy_uri', 'op_tos_uri', 'revocation_endpoint', - # 'introspection_endpoint' _supports = { # "client_authn_methods": get_client_authn_methods, "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], @@ -26,7 +31,6 @@ class WorkEnvironment(work_environment.WorkEnvironment): "op_policy_uri": None, } - callback_path = {} callback_uris = ["redirect_uris"] @@ -34,12 +38,12 @@ class WorkEnvironment(work_environment.WorkEnvironment): def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): - work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) + claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) def provider_info(self, supports): _info = {} for key in ASConfigurationResponse.c_param.keys(): _val = self.get_preference(key, supports.get(key, None)) - if _val: + if _val and _val != []: _info[key] = _val return _info diff --git a/src/idpyoidc/server/work_environment/oidc.py b/src/idpyoidc/server/claims/oidc.py similarity index 52% rename from src/idpyoidc/server/work_environment/oidc.py rename to src/idpyoidc/server/claims/oidc.py index a5503481..26434eaa 100644 --- a/src/idpyoidc/server/work_environment/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -1,12 +1,38 @@ from typing import Optional -from idpyoidc import work_environment as WE +from idpyoidc import claims from idpyoidc.message.oidc import ProviderConfigurationResponse -from idpyoidc.server import work_environment +from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.message.oidc import RegistrationResponse +from idpyoidc.server import claims as server_claims +REGISTER2PREFERRED = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "response_types": "response_types_supported", + "grant_types": "grant_types_supported", + # In OAuth2 but not in OIDC + "scope": "scopes_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", +} -class WorkEnvironment(work_environment.WorkEnvironment): - parameter = work_environment.WorkEnvironment.parameter.copy() + +class Claims(server_claims.Claims): + parameter = server_claims.Claims.parameter.copy() + + registration_response = RegistrationResponse + registration_request = RegistrationRequest _supports = { "acr_values_supported": None, @@ -19,9 +45,9 @@ class WorkEnvironment(work_environment.WorkEnvironment): "display_values_supported": None, "encrypt_id_token_supported": None, "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "id_token_signing_alg_values_supported": WE.get_signing_algs, - "id_token_encryption_alg_values_supported": WE.get_encryption_algs, - "id_token_encryption_enc_values_supported": WE.get_encryption_encs, + "id_token_signing_alg_values_supported": claims.get_signing_algs, + "id_token_encryption_alg_values_supported": claims.get_encryption_algs, + "id_token_encryption_enc_values_supported": claims.get_encryption_encs, "initiate_login_uri": None, "jwks": None, "jwks_uri": None, @@ -29,17 +55,20 @@ class WorkEnvironment(work_environment.WorkEnvironment): "require_auth_time": None, "scopes_supported": ["openid"], "service_documentation": None, + 'subject_types_supported': ['public', 'pairwise', 'ephemeral'], "op_tos_uri": None, "ui_locales_supported": None, # "version": '3.0' # "verify_args": None, } + register2preferred = REGISTER2PREFERRED + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None ): - work_environment.WorkEnvironment.__init__(self, prefer=prefer, callback_path=callback_path) + server_claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) def verify_rules(self): if self.get_preference("request_parameter_supported") and self.get_preference( @@ -64,6 +93,7 @@ def provider_info(self, supports): _info = {} for key in ProviderConfigurationResponse.c_param.keys(): _val = self.get_preference(key, supports.get(key, None)) - if _val is not None: + if _val not in [None, []]: _info[key] = _val + return _info diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index fc8bb5de..1514a344 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -13,6 +13,7 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.node import Node from idpyoidc.server.client_authn import verify_client from idpyoidc.server.exception import UnAuthorizedClient from idpyoidc.server.util import OAUTH2_NOCACHE_HEADERS @@ -77,7 +78,7 @@ def fragment_encoding(return_type): return True -class Endpoint(object): +class Endpoint(Node): request_cls = Message response_cls = Message error_cls = ResponseMessage @@ -102,6 +103,8 @@ def __init__(self, upstream_get: Callable, **kwargs): self.kwargs = kwargs self.full_path = "" + Node.__init__(self, upstream_get=upstream_get) + for param in [ "request_cls", "response_cls", diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 95abc2d1..fa881379 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -11,15 +11,19 @@ from requests import request from idpyoidc.context import OidcContext -from idpyoidc.message.oidc import ProviderConfigurationResponse +from idpyoidc.server import authz +from idpyoidc.server.claims import Claims +from idpyoidc.server.claims.oauth2 import Claims as OAUTH2_Claims +from idpyoidc.server.claims.oidc import Claims as OIDC_Claims +from idpyoidc.server.client_authn import client_auth_setup from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.scopes import SCOPE2CLAIMS from idpyoidc.server.scopes import Scopes +from idpyoidc.server.session.manager import create_session_manager from idpyoidc.server.session.manager import SessionManager from idpyoidc.server.template_handler import Jinja2TemplateHandler +from idpyoidc.server.user_authn.authn_context import populate_authn_broker from idpyoidc.server.util import get_http_params -from idpyoidc.server.work_environment.oauth2 import WorkEnvironment as OAUTH2_Env -from idpyoidc.server.work_environment.oidc import WorkEnvironment as OIDC_Env from idpyoidc.util import importer from idpyoidc.util import rndstr @@ -114,6 +118,8 @@ class EndpointContext(OidcContext): "client_authn_method": {}, } + init_args = ['upstream_get', 'handler'] + def __init__( self, conf: Union[dict, OPConfiguration], @@ -123,19 +129,23 @@ def __init__( httpc: Optional[Any] = None, server_type: Optional[str] = '', entity_id: Optional[str] = "", - keyjar: Optional[KeyJar] = None + keyjar: Optional[KeyJar] = None, + claims_class: Optional[Claims] = None ): _id = entity_id or conf.get("issuer", "") OidcContext.__init__(self, conf, entity_id=_id) self.conf = conf self.upstream_get = upstream_get - if not server_type or server_type == "oidc": - self.work_environment = OIDC_Env() - elif server_type == "oauth2": - self.work_environment = OAUTH2_Env() + if claims_class: + self.claims = claims_class else: - raise ValueError(f"Unknown server type: {server_type}") + if not server_type or server_type == "oidc": + self.claims = OIDC_Claims() + elif server_type == "oauth2": + self.claims = OAUTH2_Claims() + else: + raise ValueError(f"Unknown server type: {server_type}") _client_db = conf.get("client_db") if _client_db: @@ -145,7 +155,7 @@ def __init__( logger.debug("No special client db, will use memory based dictionary") self.cdb = {} - # For my Dev environment + # For my Dev claims self.jti_db = {} self.registration_access_token = {} # self.session_db = {} @@ -244,9 +254,63 @@ def __init__( if isinstance(conf, OPConfiguration): conf = conf.conf _supports = self.supports() - self.keyjar = self.work_environment.load_conf(conf, supports=_supports, keyjar=keyjar) - self.provider_info = self.work_environment.provider_info(_supports) + self.keyjar = self.claims.load_conf(conf, supports=_supports, keyjar=keyjar) + self.provider_info = self.claims.provider_info(_supports) self.provider_info['issuer'] = self.issuer + self.provider_info.update(self._get_endpoint_info()) + + # INTERFACES + + self.authz = self.setup_authz() + + self.setup_authentication() + + self.session_manager = create_session_manager( + self.unit_get, + self.th_args, + sub_func=self._sub_func, + conf=self.conf, + ) + + self.do_userinfo() + + # Must be done after userinfo + self.setup_login_hint_lookup() + self.set_remember_token() + + self.setup_client_authn_methods() + + _id_token_handler = self.session_manager.token_handler.handler.get("id_token") + # if _id_token_handler: + # self.provider_info.update(_id_token_handler.provider_info) + + def setup_authz(self): + authz_spec = self.conf.get("authz") + if authz_spec: + return init_service(authz_spec, self.unit_get) + else: + return authz.Implicit(self.unit_get) + + def setup_client_authn_methods(self): + self.client_authn_methods = client_auth_setup( + self.upstream_get, self.conf.get("client_authn_methods") + ) + + def setup_login_hint_lookup(self): + _conf = self.conf.get("login_hint_lookup") + if _conf: + _userinfo = None + _kwargs = _conf.get("kwargs") + if _kwargs: + _userinfo_conf = _kwargs.get("userinfo") + if _userinfo_conf: + _userinfo = init_user_info(_userinfo_conf, self.cwd) + + if _userinfo is None: + _userinfo = self.userinfo + + self.login_hint_lookup = init_service(_conf) + self.login_hint_lookup.userinfo = _userinfo def new_cookie(self, name: str, max_age: Optional[int] = 0, **kwargs): cookie_cont = self.cookie_handler.make_cookie_content( @@ -353,26 +417,17 @@ def supports(self): if self.upstream_get: for endpoint in self.upstream_get('endpoints').values(): res.update(endpoint.supports()) - res.update(self.work_environment.supports()) + res.update(self.claims.supports()) return res def set_provider_info(self): - prefers = self.work_environment.prefer - supported = self.supports() - _info = {'issuer': self.issuer, 'version': "3.0"} + _info = self.claims.provider_info(self.supports()) + _info.update({'issuer': self.issuer, 'version': "3.0"}) for endp in self.upstream_get('endpoints').values(): if endp.endpoint_name: _info[endp.endpoint_name] = endp.full_path - for key, spec in ProviderConfigurationResponse.c_param.items(): - _val = prefers.get(key, None) - if not _val and _val != False: - _val = supported.get(key, None) - if not _val and _val != False: - continue - _info[key] = _val - # acr_values if 'acr_values_supported' not in _info: if self.authn_broker: @@ -383,13 +438,69 @@ def set_provider_info(self): self.provider_info = _info def get_preference(self, claim, default=None): - return self.work_environment.get_preference(claim, default=default) + return self.claims.get_preference(claim, default=default) def set_preference(self, key, value): - self.work_environment.set_preference(key, value) + self.claims.set_preference(key, value) def get_usage(self, claim, default: Optional[str] = None): - return self.work_environment.get_usage(claim, default) + return self.claims.get_usage(claim, default) def set_usage(self, claim, value): - return self.work_environment.set_usage(claim, value) + return self.claims.set_usage(claim, value) + + def setup_authentication(self): + _conf = self.conf.get("authentication") + if _conf: + self.authn_broker = populate_authn_broker( + _conf, self.upstream_get, self.template_handler + ) + else: + self.authn_broker = {} + + self.endpoint_to_authn_method = {} + for method in self.authn_broker: + try: + self.endpoint_to_authn_method[method.action] = method + except AttributeError: + pass + + def unit_get(self, what, *arg): + _func = getattr(self, f"get_{what}", None) + if _func: + return _func(*arg) + return None + + def get_attribute(self, attr, *args): + try: + val = getattr(self, attr) + except AttributeError: + if self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return None + else: + if val is None and self.upstream_get: + return self.upstream_get("attribute", attr) + else: + return val + + def set_attribute(self, attr, val): + setattr(self, attr, val) + + def get_unit(self, *args): + return self + + def get_context(self, *args): + return self + + def map_supported_to_preferred(self): + self.claims.supported_to_preferred(self.supports()) + return self.claims.prefer + + def _get_endpoint_info(self): + _res = {} + for name, endp in self.upstream_get('endpoints').items(): + if endp.endpoint_name: + _res[endp.endpoint_name] = endp.full_path + return _res diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 4bb431c8..e426acd3 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -131,13 +131,11 @@ def token_args(context, client_id, token_args: Optional[dict] = None): return token_args -def add_support(endpoint, **kwargs): +def add_support(endpoint: dict, **kwargs): # _token_endp = endpoint["token"] _token_endp.post_parse_request.append(post_parse_request) - # Endpoint Context stuff - # _endp.context.token_args_methods.append(token_args) _algs_supported = kwargs.get("dpop_signing_alg_values_supported") if not _algs_supported: _algs_supported = ["RS256"] diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index 2ecf3b00..a7854959 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -20,7 +20,7 @@ class Introspection(Endpoint): response_format = "json" endpoint_name = "introspection_endpoint" name = "introspection" - default_capabilities = { + _supports = { "client_authn_method": [ "client_secret_basic", "client_secret_post", diff --git a/src/idpyoidc/server/oauth2/server_metadata.py b/src/idpyoidc/server/oauth2/server_metadata.py index 3e0230d8..2f9cea10 100755 --- a/src/idpyoidc/server/oauth2/server_metadata.py +++ b/src/idpyoidc/server/oauth2/server_metadata.py @@ -13,11 +13,11 @@ class ServerMetadata(Endpoint): response_format = "json" name = "server_metadata" - def __init__(self, server_get, **kwargs): - Endpoint.__init__(self, server_get=server_get, **kwargs) + def __init__(self, upstream_get, **kwargs): + Endpoint.__init__(self, upstream_get=upstream_get, **kwargs) self.pre_construct.append(self.add_endpoints) - def add_endpoints(self, request, client_id, endpoint_context, **kwargs): + def add_endpoints(self, request, client_id, context, **kwargs): for endpoint in [ "authorization_endpoint", "registration_endpoint", @@ -25,11 +25,11 @@ def add_endpoints(self, request, client_id, endpoint_context, **kwargs): "userinfo_endpoint", "end_session_endpoint", ]: - endp_instance = self.server_get("endpoint", endpoint) + endp_instance = self.upstream_get("endpoint", endpoint) if endp_instance: request[endpoint] = endp_instance.endpoint_path return request def process_request(self, request=None, **kwargs): - return {"response_args": self.server_get("endpoint_context").provider_info} + return {"response_args": self.upstream_get("context").provider_info} diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index b6293806..3182cb22 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -2,7 +2,7 @@ from typing import Callable from urllib.parse import urlsplit -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.message import oidc from idpyoidc.message.oidc import Claims from idpyoidc.message.oidc import verified_claim_name @@ -78,10 +78,10 @@ class Authorization(authorization.Authorization): _supports = { "claims_parameter_supported": True, - "encrypt_request_object_supported": None, - "request_object_signing_alg_values_supported": work_environment.get_signing_algs, - "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, - "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, + "encrypt_request_object_supported": True, + "request_object_signing_alg_values_supported": claims.get_signing_algs, + "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + "request_object_encryption_enc_values_supported": claims.get_encryption_encs, "request_parameter_supported": True, "request_uri_parameter_supported": True, "require_request_uri_registration": False, @@ -93,7 +93,6 @@ class Authorization(authorization.Authorization): def __init__(self, upstream_get: Callable, **kwargs): authorization.Authorization.__init__(self, upstream_get, **kwargs) - # self.pre_construct.append(self._pre_construct) self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index 6607cebe..60134e90 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -323,8 +323,6 @@ def is_usable(self, request=None, authorization_token=None): def _verify( self, - context: EndpointContext, - request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] get_client_id_from_token: Optional[Callable] = None, diff --git a/src/idpyoidc/server/oidc/provider_config.py b/src/idpyoidc/server/oidc/provider_config.py index 5f6478a6..51f9a9d4 100755 --- a/src/idpyoidc/server/oidc/provider_config.py +++ b/src/idpyoidc/server/oidc/provider_config.py @@ -20,7 +20,7 @@ def __init__(self, upstream_get, **kwargs): def add_endpoints(self, request, client_id, context, **kwargs): for endpoint in [ "authorization", - "provider_config", + # "provider_config", "token", "userinfo", "session", diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 1c159ada..09709f71 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -10,9 +10,6 @@ from cryptojwt.jws.utils import alg2keytype from cryptojwt.utils import as_bytes -from idpyoidc.client.oidc import PREFERENCE2PROVIDER -# from idpyoidc.defaults import PREFERENCE2SUPPORTED -from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED from idpyoidc.exception import MessageException from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import ClientRegistrationErrorResponse @@ -139,14 +136,14 @@ def match_claim(self, claim, val): _context = self.upstream_get("context") # Use my defaults - _my_key = REGISTER2PREFERRED.get(claim, claim) + _my_key = _context.claims.register2preferred.get(claim, claim) try: _val = _context.provider_info[_my_key] except KeyError: return val try: - _claim_spec = RegistrationResponse.c_param[claim] + _claim_spec = _context.claims.registration_response.c_param[claim] except KeyError: # something I don't know anything about return None @@ -169,9 +166,10 @@ def match_claim(self, claim, val): def filter_client_request(self, request: dict) -> dict: _args = {} - _provider_info = self.upstream_get("context").provider_info + _context = self.upstream_get("context") + _provider_info = _context.provider_info for key, val in request.items(): - if key not in REGISTER2PREFERRED: + if key not in _context.claims.register2preferred: _args[key] = val continue @@ -252,7 +250,8 @@ def do_client_registration(self, request, client_id, ignore=None): # Do I have the necessary keys for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: if item in request: - if request[item] in _context.provider_info[PREFERENCE2PROVIDER[item]]: + if request[item] in _context.provider_info[ + _context.claims.register2preferred[item]]: ktyp = alg2keytype(request[item]) # do I have this ktyp and for EC type keys the curve if ktyp not in ["none", "oct"]: @@ -402,7 +401,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): _error = "invalid_configuration_request" if len(err.args) > 1: if err.args[1] == "initiate_login_uri": - _error = "invalid_client_metadata" + _error = "invalid_client_claims" return ResponseMessage(error=_error, error_description="%s" % err) @@ -496,6 +495,6 @@ def process_verify_error(self, exception): if isinstance(exception, ValueError): if len(exception.args) > 1: if exception.args[1] == "initiate_login_uri": - _error = "invalid_client_metadata" + _error = "invalid_client_claims" return self.error_cls(error=_error, error_description=f"{exception}") diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index e9d2509e..99e30b0c 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -92,7 +92,8 @@ class Session(Endpoint): def __init__(self, upstream_get, **kwargs): _csi = kwargs.get("check_session_iframe") if _csi and not _csi.startswith("http"): - kwargs["check_session_iframe"] = add_path(upstream_get("context").issuer, _csi) + # unit since context does not exist at this point in time + kwargs["check_session_iframe"] = add_path(upstream_get("unit").issuer, _csi) Endpoint.__init__(self, upstream_get, **kwargs) self.iv = as_bytes(rndstr(24)) diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index 5c4436f2..b4280901 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -1,6 +1,6 @@ import logging -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.message import Message from idpyoidc.message import oidc @@ -33,7 +33,7 @@ class Token(token.Token): "client_secret_jwt", "private_key_jwt", ], - "token_endpoint_auth_signing_alg_values_supported": work_environment.get_signing_algs, + "token_endpoint_auth_signing_alg_values_supported": claims.get_signing_algs, } helper_by_grant_type = { diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 78a593bb..31bbce7d 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -9,7 +9,7 @@ from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -30,10 +30,10 @@ class UserInfo(Endpoint): name = "userinfo" _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], - "encrypt_userinfo_supported": False, - "userinfo_signing_alg_values_supported": work_environment.get_signing_algs, - "userinfo_encryption_alg_values_supported": work_environment.get_encryption_algs, - "userinfo_encryption_enc_values_supported": work_environment.get_encryption_encs, + "encrypt_userinfo_supported": True, + "userinfo_signing_alg_values_supported": claims.get_signing_algs, + "userinfo_encryption_alg_values_supported": claims.get_encryption_algs, + "userinfo_encryption_enc_values_supported": claims.get_encryption_encs, } def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): diff --git a/src/idpyoidc/server/session/grant_manager.py b/src/idpyoidc/server/session/grant_manager.py index 15a3a2fb..ec99134c 100644 --- a/src/idpyoidc/server/session/grant_manager.py +++ b/src/idpyoidc/server/session/grant_manager.py @@ -287,6 +287,6 @@ def flush(self): # -def create_grant_manager(server_get, token_handler_args, conf=None, **kwargs): - _token_handler = handler.factory(server_get, **token_handler_args) +def create_grant_manager(upstream_get, token_handler_args, conf=None, **kwargs): + _token_handler = handler.factory(upstream_get, **token_handler_args) return GrantManager(_token_handler, conf=conf) diff --git a/src/idpyoidc/server/session/token.py b/src/idpyoidc/server/session/token.py index 0d6d6c80..a450f1bd 100644 --- a/src/idpyoidc/server/session/token.py +++ b/src/idpyoidc/server/session/token.py @@ -86,8 +86,6 @@ class SessionToken(Item): "resources": [], "scope": [], "token_class": "", - "usage_rules": {}, - "used": 0, "value": "", } ) diff --git a/src/idpyoidc/server/token/handler.py b/src/idpyoidc/server/token/handler.py index 4cebeacc..20b36fa5 100755 --- a/src/idpyoidc/server/token/handler.py +++ b/src/idpyoidc/server/token/handler.py @@ -10,7 +10,6 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.util import importer - from . import DefaultToken from . import Token from . import UnknownToken @@ -25,11 +24,11 @@ class TokenHandler(ImpExp): parameter = {"handler": DLDict, "handler_order": [""]} def __init__( - self, - access_token: Optional[Token] = None, - authorization_code: Optional[Token] = None, - refresh_token: Optional[Token] = None, - id_token: Optional[Token] = None, + self, + access_token: Optional[Token] = None, + authorization_code: Optional[Token] = None, + refresh_token: Optional[Token] = None, + id_token: Optional[Token] = None, ): ImpExp.__init__(self) self.handler = {"authorization_code": authorization_code, "access_token": access_token} @@ -142,13 +141,13 @@ def default_token(spec): def factory( - upstream_get, - code: Optional[dict] = None, - token: Optional[dict] = None, - refresh: Optional[dict] = None, - id_token: Optional[dict] = None, - jwks_file: Optional[str] = "", - **kwargs + upstream_get, + code: Optional[dict] = None, + token: Optional[dict] = None, + refresh: Optional[dict] = None, + id_token: Optional[dict] = None, + jwks_file: Optional[str] = "", + **kwargs ) -> TokenHandler: """ Create a token handler @@ -169,7 +168,7 @@ def factory( key_defs = [] read_only = False - cwd = upstream_get("context").cwd + cwd = upstream_get("attribute", "cwd") if kwargs.get("jwks_def"): defs = kwargs["jwks_def"] if not jwks_file: diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index f7e8f652..0840ef5f 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -110,7 +110,8 @@ def get_sign_and_encrypt_algorithms( class IDToken(Token): - default_capabilities = { + _supports = { + "encrypt_id_token_supported": None, "id_token_signing_alg_values_supported": None, "id_token_encryption_alg_values_supported": None, "id_token_encryption_enc_values_supported": None, @@ -128,7 +129,7 @@ def __init__( self.upstream_get = upstream_get self.kwargs = kwargs self.scope_to_claims = None - self.provider_info = construct_provider_info(self.default_capabilities, **kwargs) + self.provider_info = construct_provider_info(self._supports, **kwargs) def payload( self, @@ -334,7 +335,7 @@ def info(self, token): if is_expired(_payload["exp"]): raise ToOld("Token has expired") - # All the token metadata + # All the token claims return { "sid": _payload.get("sid", ""), # TODO: would sid be there? # "type": _payload["ttype"], diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index ce03bb0c..6cb12d7a 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -141,7 +141,7 @@ def info(self, token): if is_expired(_payload["exp"]): raise ToOld("Token has expired") - # All the token metadata + # All the token claims _res = { "sid": _payload["sid"], "token_class": _payload["token_class"], diff --git a/src/idpyoidc/server/user_authn/authn_context.py b/src/idpyoidc/server/user_authn/authn_context.py index 08ab8ffe..4e42775d 100755 --- a/src/idpyoidc/server/user_authn/authn_context.py +++ b/src/idpyoidc/server/user_authn/authn_context.py @@ -28,7 +28,7 @@ def __setitem__(self, key, info): """ Adds a new authentication method. - :param value: A dictionary with metadata and configuration information + :param value: A dictionary with claims and configuration information """ for attr in ["acr", "method"]: diff --git a/src/idpyoidc/server/work_environment/__init__.py b/src/idpyoidc/server/work_environment/__init__.py deleted file mode 100644 index 06bbda43..00000000 --- a/src/idpyoidc/server/work_environment/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from idpyoidc import work_environment - - -class WorkEnvironment(work_environment.WorkEnvironment): - - def get_base_url(self, configuration: dict): - _base = configuration.get('base_url') - if not _base: - _base = configuration.get('issuer') - - return _base - - def get_id(self, configuration: dict): - return configuration.get('issuer') diff --git a/src/idpyoidc/storage/abfile.py b/src/idpyoidc/storage/abfile.py index a6577993..e6f980c0 100644 --- a/src/idpyoidc/storage/abfile.py +++ b/src/idpyoidc/storage/abfile.py @@ -15,7 +15,7 @@ class AbstractFileSystem(DictType): """ - FileSystem implements a simple file based database. + FileSystem implements a simple read-only file based database. It has a dictionary like interface. Each key maps one-to-one to a file on disc, where the content of the file is the value. @@ -24,7 +24,10 @@ class AbstractFileSystem(DictType): """ def __init__( - self, fdir: Optional[str] = "", key_conv: Optional[str] = "", value_conv: Optional[str] = "" + self, fdir: Optional[str] = "", + key_conv: Optional[str] = "", + value_conv: Optional[str] = "", + **kwargs ): """ items = FileSystem( @@ -87,6 +90,7 @@ def __getitem__(self, item): self.storage[item] = self._read_info(fname) logger.debug('Read from "%s"', item) + # storage values are already value converted return self.storage[item] def __setitem__(self, key, value): diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 5f373e53..28f863e1 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJBVTdNa3Z0cnNJRUxqRTE1dEVUeGx6ck9GdVZPUVRSM2h0ZldLMlcyakN3IiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjcwNDk2MzA0LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.AJjkd2WenuqnvpKI1rODXmXTK_CvWR7zJ8EVB3y7y_nTK8xajubBQQbXJql1r6r2yzxGC7wXOXQnp-4CNFV45pHyjawxGbA-p-Ko4sdTzebiJDOf-JGPdh0hzWff0oepU0zsL3vqg9L8V534Z4v6ugZDYw1EUZaht5xvRFAUEyxwG6rEf05DRQif01288Zbnc8i5oCLpevCreTlKlo7_jEcJVSKlnmuyTyGpDGENgjt2U3hNb7pFMKOw8J848vq4ukQvDVlD_7qBzt_-VDN_NWIFkeSp2-1e_AbZtsQdXC-gLo9xaTOoS5hG5Eh1-fdzLGdmdb0m4Tz6stlFF_AWbw \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJJTVFraVkxckVhT2pncW5VZkpGSjN6dGV1MG9QMDJ2S1J5d0xyM0p1aHFjIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjc1NjczNjU1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.Oj4q4UDeBTbkpI3oAGl6Bwt_DS1_rHJQxmpLwkKQTEgaTh08Fhr64iZoxUyyJYOZGkmMlgXz5nJZLt1uO5uotsA2wZaoAn6-EMXZ8lfm8vDxq5YdqoJX_8UfE3HSQDlmIsuHdtjOSYijYUP2FtSMutryzxkAW9Sp50GpaJ6QmpL_GE55lEfpHpR4A_2rf0SikwW2xHtMYU90XI0Jv_m-6rBf6sTlaqePge6ToNjCxpyYDOWKMa-qwrMeFe99JECzDdMbMYXQB2WPmRVuFkV7mJOoxFY7wviqjT_-eM3YI_jKDPB6M2-oVXQ7IjUv5t3WqYWasEoGEXgkMk-WcfBtTw \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 606e87b6..dab3adcf 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -1,21 +1,22 @@ from typing import Callable -from cryptojwt.utils import importer import pytest +from cryptojwt.utils import importer -from idpyoidc.client.work_environment.oidc import WorkEnvironment as WorkEnvironmentOIDC -from idpyoidc.client.work_environment.transform import REGISTER2PREFERRED -from idpyoidc.client.work_environment.transform import create_registration_request -from idpyoidc.client.work_environment.transform import preferred_to_registered -from idpyoidc.client.work_environment.transform import supported_to_preferred +from idpyoidc.claims import Claims +from idpyoidc.client.claims.oidc import Claims as OIDC_Claims +from idpyoidc.client.claims.transform import create_registration_request +from idpyoidc.client.claims.transform import preferred_to_registered +from idpyoidc.client.claims.transform import supported_to_preferred from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest class TestTransform: + @pytest.fixture(autouse=True) def setup(self): - supported = WorkEnvironmentOIDC._supports.copy() + supported = OIDC_Claims._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -149,35 +150,39 @@ def test_oidc_setup(self): 'token_endpoint_auth_method', 'tos_uri'} - preference = {} - pref = supported_to_preferred(supported=self.supported, preference=preference, - base_url='https://example.com') + claims = OIDC_Claims() + # No input from the IDP so info is absent + claims.prefer = supported_to_preferred(supported=self.supported, + preference=claims.prefer, + base_url='https://example.com') # These are the claims that has default values. A default value may be an empty list. # This is the case for claims like id_token_encryption_enc_values_supported. - assert set(pref.keys()) == {'application_type', - 'default_max_age', - 'grant_types_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + assert set(claims.prefer.keys()) == {'application_type', + 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'subject_types_supported', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} # To verify that I have all the necessary claims to do client registration reg_claim = [] - for key, spec in RegistrationRequest.c_param.items(): - _pref_key = REGISTER2PREFERRED.get(key, key) + for key, spec in OIDC_Claims.registration_request.c_param.items(): + _pref_key = OIDC_Claims.register2preferred.get(key, key) if _pref_key in self.supported: reg_claim.append(key) @@ -188,7 +193,7 @@ def test_oidc_setup(self): l_to_s = [] non_oidc = [] - for key, pref_key in REGISTER2PREFERRED.items(): + for key, pref_key in OIDC_Claims.register2preferred.items(): spec = RegistrationRequest.c_param.get(key) if spec is None: non_oidc.append(pref_key) @@ -223,45 +228,49 @@ def test_provider_info(self): "acr_values_supported": ['mfa'], } - preference = {} - pref = supported_to_preferred(supported=self.supported, preference=preference, - base_url='https://example.com', - info=provider_info_response) + claims = OIDC_Claims() + claims.prefer = supported_to_preferred(supported=self.supported, + preference=claims.prefer, + base_url='https://example.com', + info=provider_info_response) # These are the claims that has default values - assert set(pref.keys()) == {'application_type', - 'default_max_age', - 'grant_types_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + assert set(claims.prefer.keys()) == {'application_type', + 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'grant_types_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'id_token_signing_alg_values_supported', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'subject_types_supported', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported'} # least common denominator # The RP supports less than the OP - assert pref['scopes_supported'] == ['openid'] - assert pref["response_modes_supported"] == ['query', 'form_post'] + assert claims.get_preference('scopes_supported') == ['openid'] + assert claims.get_preference("response_modes_supported") == ['query', 'form_post'] # The OP supports less than the RP - assert pref["response_types_supported"] == ['code', 'id_token', 'code id_token'] + assert claims.get_preference("response_types_supported") == ['code', 'id_token', + 'code id_token'] class TestTransform2: @pytest.fixture(autouse=True) def setup(self): - self.work_environment = WorkEnvironmentOIDC() - supported = self.work_environment._supports.copy() + self.claims = OIDC_Claims() + supported = self.claims._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -295,7 +304,7 @@ def setup(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(preference, self.supported) + self.claims.load_conf(preference, self.supported) def test_registration_response(self): OP_BASEURL = 'https://example.com' @@ -322,12 +331,13 @@ def test_registration_response(self): "acr_values_supported": ['mfa'], } - pref = supported_to_preferred(supported=self.supported, - preference=self.work_environment.prefer, - base_url='https://example.com', - info=provider_info_response) + self.claims.prefer = supported_to_preferred(supported=self.supported, + preference=self.claims.prefer, + base_url='https://example.com', + info=provider_info_response) - registration_request = create_registration_request(pref, self.supported) + registration_request = create_registration_request(prefers=self.claims.prefer, + supported=self.supported) assert set(registration_request.keys()) == {'application_type', 'client_name', @@ -364,14 +374,16 @@ def test_registration_response(self): "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] } - to_use = preferred_to_registered(prefers=pref, - supported=self.supported, + to_use = preferred_to_registered(supported=self.supported, + prefers=self.claims.prefer, registration_response=registration_response) assert set(to_use.keys()) == {'application_type', 'client_name', 'contacts', 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', 'grant_types', 'id_token_signed_response_alg', 'jwks_uri', diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 8353d34a..9fbb8b34 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -1,12 +1,12 @@ from typing import Callable -from cryptojwt.utils import importer import pytest as pytest +from cryptojwt.utils import importer -from idpyoidc.client.work_environment.oidc import WorkEnvironment as WorkEnvironmentOIDC -from idpyoidc.client.work_environment.transform import create_registration_request -from idpyoidc.client.work_environment.transform import preferred_to_registered -from idpyoidc.client.work_environment.transform import supported_to_preferred +from idpyoidc.client.claims.oidc import Claims +from idpyoidc.client.claims.transform import create_registration_request +from idpyoidc.client.claims.transform import preferred_to_registered +from idpyoidc.client.claims.transform import supported_to_preferred KEYSPEC = [ {"type": "RSA", "use": ["sig"]}, @@ -18,8 +18,8 @@ class TestWorkEnvironment: @pytest.fixture(autouse=True) def setup(self): - self.work_environment = WorkEnvironmentOIDC() - supported = self.work_environment._supports.copy() + self.claims = Claims() + supported = self.claims._supports.copy() for service in [ 'idpyoidc.client.oidc.access_token.AccessToken', 'idpyoidc.client.oidc.authorization.Authorization', @@ -57,9 +57,9 @@ def test_load_conf(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(client_conf, self.supported) - assert self.work_environment.get_preference('jwks') is None - assert self.work_environment.get_preference('jwks_uri') is None + self.claims.load_conf(client_conf, self.supported) + assert self.claims.get_preference('jwks') is None + assert self.claims.get_preference('jwks_uri') is None def test_load_jwks(self): # Symmetric and asymmetric keys published as JWKS @@ -76,9 +76,9 @@ def test_load_jwks(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(client_conf, self.supported) - assert self.work_environment.get_preference('jwks') is not None - assert self.work_environment.get_preference('jwks_uri') is None + self.claims.load_conf(client_conf, self.supported) + assert self.claims.get_preference('jwks') is not None + assert self.claims.get_preference('jwks_uri') is None def test_load_jwks_uri1(self): # Symmetric and asymmetric keys published through a jwks_uri @@ -93,9 +93,9 @@ def test_load_jwks_uri1(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(client_conf, self.supported) - assert self.work_environment.get_preference('jwks') is None - assert self.work_environment.get_preference( + self.claims.load_conf(client_conf, self.supported) + assert self.claims.get_preference('jwks') is None + assert self.claims.get_preference( 'jwks_uri') == f"{client_conf['base_url']}{client_conf['keys']['uri_path']}" def test_load_jwks_uri2(self): @@ -112,9 +112,9 @@ def test_load_jwks_uri2(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(client_conf, self.supported) - assert self.work_environment.get_preference('jwks') is None - assert self.work_environment.get_preference('jwks_uri') == client_conf['jwks_uri'] + self.claims.load_conf(client_conf, self.supported) + assert self.claims.get_preference('jwks') is None + assert self.claims.get_preference('jwks_uri') == client_conf['jwks_uri'] def test_registration_response(self): client_conf = { @@ -130,7 +130,7 @@ def test_registration_response(self): 'contacts': ["ve7jtb@example.org", "mary@example.org"] } - self.work_environment.load_conf(client_conf, self.supported) + self.claims.load_conf(client_conf, self.supported) OP_BASEURL = 'https://example.com' provider_info_response = { @@ -156,12 +156,12 @@ def test_registration_response(self): "acr_values_supported": ['mfa'], } - pref = supported_to_preferred(supported=self.supported, - preference=self.work_environment.prefer, - base_url='https://example.com', - info=provider_info_response) + pref = self.claims.prefer = supported_to_preferred(supported=self.supported, + preference=self.claims.prefer, + base_url='https://example.com', + info=provider_info_response) - registration_request = create_registration_request(pref, self.supported) + registration_request = create_registration_request(self.claims.prefer, self.supported) assert set(registration_request.keys()) == {'application_type', 'client_name', @@ -199,7 +199,7 @@ def test_registration_response(self): "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] } - to_use = preferred_to_registered(prefers=pref, + to_use = preferred_to_registered(prefers=self.claims.prefer, supported=self.supported, registration_response=registration_response) @@ -209,6 +209,8 @@ def test_registration_response(self): 'client_secret', 'contacts', 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', 'grant_types', 'id_token_signed_response_alg', 'jwks', diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index a57e9c82..a4637a42 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -2,6 +2,7 @@ from cryptojwt.key_jar import build_keyjar from idpyoidc.client.service_context import ServiceContext +from idpyoidc.node import Unit KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -24,7 +25,8 @@ class TestServiceContext: @pytest.fixture(autouse=True) def setup(self): - self.service_context = ServiceContext(config=MINI_CONFIG) + self.unit = Unit() + self.service_context = ServiceContext(config=MINI_CONFIG, upstream_get=self.unit.unit_get) def test_init(self): assert self.service_context @@ -37,11 +39,11 @@ def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") assert _alg is None - self.service_context.work_environment.set_preference("id_token_signed_response_alg", "RS384") + self.service_context.claims.set_preference("id_token_signed_response_alg", "RS384") _alg = self.service_context.get_sign_alg("id_token") assert _alg == "RS384" - self.service_context.work_environment.prefer = {} + self.service_context.claims.prefer = {} self.service_context.provider_info["id_token_signing_alg_values_supported"] = [ "RS256", "ES256", @@ -53,15 +55,14 @@ def test_get_enc_alg_enc(self): _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": None, "enc": None} - self.service_context.work_environment.set_preference("userinfo_encrypted_response_alg", - "RSA1_5") - self.service_context.work_environment.set_preference("userinfo_encrypted_response_enc", - "A128CBC+HS256") + self.service_context.claims.set_preference("userinfo_encrypted_response_alg", "RSA1_5") + self.service_context.claims.set_preference("userinfo_encrypted_response_enc", + "A128CBC+HS256") _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": "RSA1_5", "enc": "A128CBC+HS256"} - self.service_context.work_environment.prefer = {} + self.service_context.claims.prefer = {} self.service_context.provider_info["userinfo_encryption_alg_values_supported"] = [ "RSA1_5", "A128KW", diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index c1fe9163..2492dc53 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -20,7 +20,7 @@ class TestEntity: @pytest.fixture(autouse=True) def setup(self): self.entity = Entity( - config=MINI_CONFIG, services={"xyz": {"class": "idpyoidc.client.service.Service"}} + config=MINI_CONFIG.copy(), services={"xyz": {"class": "idpyoidc.client.service.Service"}} ) def test_1(self): diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 2bd8412b..491c054f 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -14,7 +14,7 @@ "preference": { "application_type": "web", "contacts": "support@example.com", - "response_types": ["code"], + "response_types_supported": ["code"], 'request_parameter': "request_uri", "request_object_signing_alg_values_supported": ["ES256"], "scope": ["openid", "profile", "email", "address", "phone"], @@ -68,33 +68,28 @@ def test_create_client(): _context = client.get_context() _context.map_supported_to_preferred() _pref = _context.prefers() - assert set(_pref.keys()) == {'application_type', - 'backchannel_logout_session_required', - 'backchannel_logout_uri', - 'callback_uris', - 'client_id', - 'client_secret', - 'contacts', - 'default_max_age', - 'grant_types_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'request_parameter', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + _pref_with_values = [k for k, v in _pref.items() if v] + assert set(_pref_with_values) == {'application_type', + 'backchannel_logout_session_required', + 'backchannel_logout_uri', + 'callback_uris', + 'client_id', + 'client_secret', + 'contacts', + 'default_max_age', + 'grant_types_supported', + 'id_token_signing_alg_values_supported', + 'post_logout_redirect_uris', + 'redirect_uris', + 'request_object_signing_alg_values_supported', + 'request_parameter', + 'response_modes_supported', + 'response_types_supported', + 'scopes_supported', + 'subject_types_supported', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg_values_supported', + 'userinfo_signing_alg_values_supported'} # What's in service configuration has higher priority then what's just supported. _context = client.get_service_context() @@ -108,32 +103,31 @@ def test_create_client(): _conf_args = list(_context.collect_usage().keys()) assert _conf_args - assert len(_conf_args) == 21 + assert len(_conf_args) == 23 rr = set(RegistrationRequest.c_param.keys()) - # The ones that are not defined + # The ones that are not defined and will therefore not appear in a registration request d = rr.difference(set(_conf_args)) - assert d == { - 'client_name', - 'client_uri', - 'default_acr_values', - 'frontchannel_logout_session_required', - 'frontchannel_logout_uri', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', - 'initiate_login_uri', - 'jwks', - 'jwks_uri', - 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uri', - 'request_object_encryption_alg', - 'request_object_encryption_enc', - 'request_uris', - 'require_auth_time', - 'sector_identifier_uri', - 'tos_uri', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc'} + assert d == {'client_name', + 'client_uri', + 'default_acr_values', + 'frontchannel_logout_session_required', + 'frontchannel_logout_uri', + 'id_token_encrypted_response_alg', + 'id_token_encrypted_response_enc', + 'initiate_login_uri', + 'logo_uri', + 'jwks', + 'jwks_uri', + 'policy_uri', + 'post_logout_redirect_uri', + 'request_object_encryption_alg', + 'request_object_encryption_enc', + 'request_uris', + 'require_auth_time', + 'sector_identifier_uri', + 'tos_uri', + 'userinfo_encrypted_response_alg', + 'userinfo_encrypted_response_enc'} def test_create_client_key_conf(): diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 121e77b8..2873dd5c 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -1,5 +1,4 @@ import pytest -from cryptojwt.key_jar import init_key_jar from idpyoidc.client.entity import Entity from idpyoidc.message.oauth2 import AuthorizationResponse @@ -33,9 +32,10 @@ class TestService: @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( - config=CLIENT_CONF, + config=CLIENT_CONF.copy(), services={"authz": {"class": "idpyoidc.client.oidc.authorization.Authorization"}}, - client_type='oidc' + client_type='oidc', + jwks_uri='https://example.com/cli/jwks.json' ) self.service = self.entity.get_service("authorization") @@ -46,7 +46,7 @@ def upstream_get(self, *args): if args[0] == "context": return self.service_context elif args[0] == 'attribute' and args[1] == 'keyjar': - return self.upstream_get('attribute','keyjar') + return self.upstream_get('attribute', 'keyjar') def test_1(self): assert self.service @@ -58,6 +58,7 @@ def test_use(self): 'callback_uris', 'client_id', 'default_max_age', + 'encrypt_request_object_supported', 'grant_types', 'id_token_signed_response_alg', 'jwks', @@ -115,7 +116,7 @@ def test_parse_response_json(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_json() arg = self.service.parse_response(resp1) assert isinstance(arg, AuthorizationResponse) @@ -127,7 +128,7 @@ def test_parse_response_jwt(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -141,7 +142,7 @@ def test_parse_response_err(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('context').keyjar.get_signing_key() + _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -154,7 +155,7 @@ class TestAuthorization(object): @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( - config=CLIENT_CONF, services={"base": {"class": "idpyoidc.client.service.Service"}} + config=CLIENT_CONF.copy(), services={"base": {"class": "idpyoidc.client.service.Service"}} ) self.service = self.entity.get_service("") diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index c1dadfe1..69eb264a 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -21,7 +21,7 @@ from idpyoidc.client.client_auth import bearer_auth from idpyoidc.client.client_auth import valid_service_context from idpyoidc.client.entity import Entity -from idpyoidc.work_environment import WorkEnvironment +from idpyoidc.claims import Claims from idpyoidc.defaults import JWT_BEARER from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenRequest @@ -308,7 +308,8 @@ def test_construct(self, entity): key.add_kid() _context = token_service.upstream_get('context') - _context.get_keyjar().add_kb("", kb_rsa) + _keyjar = token_service.upstream_get('attribute', 'keyjar') + _keyjar.add_kb("", kb_rsa) _context.provider_info = { "issuer": "https://example.com/", "token_endpoint": "https://example.com/token", @@ -324,7 +325,7 @@ def test_construct(self, entity): # Receiver _kj = KeyJar() - _kj.import_jwks(_context.keyjar.export_jwks(), issuer_id=_context.get_client_id()) + _kj.import_jwks(_keyjar.export_jwks(), issuer_id=_context.get_client_id()) _kj.add_kb(_context.get_client_id(), kb_rsa) jso = JWT(key_jar=_kj).unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) @@ -462,7 +463,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): _rsa_key = entity.keyjar.get(key_use='sig', key_type='rsa', issuer_id='')[0] _jws = factory(request["client_assertion"]) assert _jws.jwt.headers["alg"] == "RS256" - _rsa_key = _service_context.keyjar.get_signing_key(key_type="RSA")[0] + _rsa_key = entity.keyjar.get_signing_key(key_type="RSA")[0] assert _jws.jwt.headers["kid"] == _rsa_key.kid # By client preferences @@ -476,7 +477,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Use provider information is everything else fails request = AccessTokenRequest() - _service_context.work_environment = WorkEnvironment() + _service_context.claims = Claims() _service_context.provider_info["token_endpoint_auth_signing_alg_values_supported"] = [ "ES256", "RS256", @@ -487,7 +488,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): _jws = factory(request["client_assertion"]) # Should be ES256 since I have a key for ES256 assert _jws.jwt.headers["alg"] == "ES256" - _ec_key = _service_context.keyjar.get_signing_key(key_type="EC")[0] + _ec_key = entity.keyjar.get_signing_key(key_type="EC")[0] assert _jws.jwt.headers["kid"] == _ec_key.kid diff --git a/tests/test_client_10_entity.py b/tests/test_client_10_entity.py index 095e2a5f..3a3f3a7f 100644 --- a/tests/test_client_10_entity.py +++ b/tests/test_client_10_entity.py @@ -22,8 +22,7 @@ def create_client_info_instance(self): self.entity = Entity(config=config) def test_import_keys_file(self): - # Should only be one and that a symmetric key (client_secret) usable - # for signing and encryption + # Should only be one, a symmetric key (client_secret) assert len(self.entity.keyjar.get_issuer_keys("")) == 1 file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "salesforce.key")) diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index 976f475d..27e3f1ad 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -20,7 +20,7 @@ def test_client_info_init(): "requests_dir": "requests", } ci = ServiceContext(config=config, client_type='oidc') - ci.work_environment.load_conf(config, supports=ci.supports()) + ci.claims.load_conf(config, supports=ci.supports()) ci.map_supported_to_preferred() ci.map_preferred_to_registered() @@ -111,7 +111,7 @@ def create_client_info_instance(self): self.service_context = self.entity.get_context() def test_registration_userinfo_sign_enc_algs(self): - self.service_context.work_environment.use = { + self.service_context.claims.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -130,7 +130,7 @@ def test_registration_userinfo_sign_enc_algs(self): assert srvcntx.get_enc_alg_enc("userinfo") == {"alg": "RSA1_5", "enc": "A128CBC-HS256"} def test_registration_request_object_sign_enc_algs(self): - self.service_context.work_environment.use = { + self.service_context.claims.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -152,7 +152,7 @@ def test_registration_request_object_sign_enc_algs(self): assert srvcntx.get_sign_alg("request_object") == "RS384" def test_registration_id_token_sign_enc_algs(self): - self.service_context.work_environment.use = { + self.service_context.claims.use = { "application_type": "web", "redirect_uris": [ "https://client.example.org/callback", @@ -250,7 +250,8 @@ def test_verify_alg_support(self): def test_import_keys_file(self): # Should only be one and that a symmetric key (client_secret) usable # for signing and encryption - assert len(self.service_context.keyjar.get_issuer_keys("")) == 1 + _keyjar = self.entity.keyjar + assert len(_keyjar.get_issuer_keys("")) == 1 file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "salesforce.key")) @@ -262,28 +263,31 @@ def test_import_keys_file(self): ) # Now there should be 2, the second a RSA key for signing - assert len(srvcntx.keyjar.get_issuer_keys("")) == 2 + assert len(_keyjar.get_issuer_keys("")) == 2 def test_import_keys_file_json(self): # Should only be one and that a symmetric key (client_secret) usable # for signing and encryption - assert len(self.service_context.keyjar.get_issuer_keys("")) == 1 + _keyjar = self.entity.keyjar + assert len(_keyjar.get_issuer_keys("")) == 1 file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "salesforce.key")) keyspec = {"file": {"rsa": [file_path]}} self.service_context.import_keys(keyspec) - _sc_state = self.service_context.dump(exclude_attributes=["context"]) + _sc_state = self.service_context.dump(exclude_attributes=["context", 'upstream_get']) _jsc_state = json.dumps(_sc_state) _o_state = json.loads(_jsc_state) - srvcntx = ServiceContext(base_url=BASE_URL).load(_o_state) + srvcntx = ServiceContext(base_url=BASE_URL).load(_o_state, init_args={ + 'upstream_get': self.service_context.upstream_get}) # Now there should be 2, the second a RSA key for signing - assert len(srvcntx.keyjar.get_issuer_keys("")) == 2 + assert len(srvcntx.upstream_get('attribute', 'keyjar').get_issuer_keys("")) == 2 def test_import_keys_url(self): - assert len(self.service_context.keyjar.get_issuer_keys("")) == 1 + _keyjar = self.service_context.upstream_get('attribute', 'keyjar') + assert len(_keyjar.get_issuer_keys("")) == 1 # One EC key for signing key_def = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] @@ -301,11 +305,13 @@ def test_import_keys_url(self): ) keyspec = {"url": {"https://foobar.com": _jwks_url}} self.service_context.import_keys(keyspec) - self.service_context.keyjar.update() + _keyjar.update() srvcntx = ServiceContext(base_url=BASE_URL).load( - self.service_context.dump(exclude_attributes=["context"]) + self.service_context.dump(exclude_attributes=["context"]), + init_args={'upstream_get': self.service_context.upstream_get} ) # Now there should be one belonging to https://example.com - assert len(srvcntx.keyjar.get_issuer_keys("https://foobar.com")) == 1 + assert len(srvcntx.upstream_get('attribute', 'keyjar').get_issuer_keys( + "https://foobar.com")) == 1 diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index b9d67fb0..81defa1b 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -198,5 +198,5 @@ def create_client(self): def test_keyjar(self): _keyjar = self.client.get_attribute('keyjar') assert len(_keyjar) == 2 # one issuer - assert len(_keyjar[""]) == 2 - assert len(_keyjar.get("sig")) == 2 + assert len(_keyjar[""]) == 3 + assert len(_keyjar.get("sig")) == 3 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 5e36f797..a2de0faa 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1,13 +1,13 @@ import os +import pytest +import responses from cryptojwt.exception import UnsupportedAlgorithm from cryptojwt.jws import jws from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar -import pytest -import responses from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.entity import Entity @@ -297,7 +297,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.upstream_get("context").work_environment.set_usage("verify_args", { + self.service.upstream_get("context").claims.set_usage("verify_args", { "allow_sign_alg_none": allow_sign_alg_none }) resp = AuthorizationResponse(state="state", code="code", id_token=_idt) @@ -726,7 +726,7 @@ def test_post_parse(self): "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } _context = self.service.upstream_get("context") - assert _context.work_environment.use == {} + assert _context.claims.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -738,7 +738,7 @@ def test_post_parse(self): # static client registration _context.map_preferred_to_registered() - use_copy = self.service.upstream_get("context").work_environment.use.copy() + use_copy = self.service.upstream_get("context").claims.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] @@ -752,6 +752,8 @@ def test_post_parse(self): 'contacts': ['ops@example.org'], 'default_max_age': 86400, 'encrypt_id_token_supported': False, + 'encrypt_request_object_supported': False, + 'encrypt_userinfo_supported': False, 'grant_types': ['authorization_code', 'refresh_token'], 'id_token_signed_response_alg': 'RS256', 'post_logout_redirect_uris': ['https://rp.example.com/post'], @@ -763,8 +765,7 @@ def test_post_parse(self): 'subject_type': 'public', 'token_endpoint_auth_method': 'private_key_jwt', 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256' - } + 'userinfo_signed_response_alg': 'ES256'} def test_post_parse_2(self): OP_BASEURL = ISS @@ -786,7 +787,7 @@ def test_post_parse_2(self): "end_session_endpoint": "{}/end_session".format(OP_BASEURL), } _context = self.service.upstream_get("context") - assert _context.work_environment.use == {} + assert _context.claims.use == {} resp = self.service.post_parse_response(provider_info_response) iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) @@ -798,33 +799,34 @@ def test_post_parse_2(self): # static client registration _context.map_preferred_to_registered() - use_copy = self.service.upstream_get("context").work_environment.use.copy() + use_copy = self.service.upstream_get("context").claims.use.copy() # jwks content will change dynamically between runs assert 'jwks' in use_copy del use_copy['jwks'] del use_copy['callback_uris'] - assert use_copy == { - 'application_type': 'web', - 'backchannel_logout_session_required': True, - 'backchannel_logout_uri': 'https://rp.example.com/back', - 'client_id': 'client_id', - 'client_secret': 'a longesh password', - 'contacts': ['ops@example.org'], - 'default_max_age': 86400, - 'encrypt_id_token_supported': False, - 'grant_types': ['authorization_code', 'implicit', 'refresh_token'], - 'id_token_signed_response_alg': 'RS256', - 'post_logout_redirect_uris': ['https://rp.example.com/post'], - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'request_object_signing_alg': 'ES256', - 'response_modes_supported': ['query', 'fragment', 'form_post'], - 'response_types': ['code'], - 'scope': ['openid'], - 'subject_type': 'public', - 'token_endpoint_auth_method': 'private_key_jwt', - 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256'} + assert use_copy == {'application_type': 'web', + 'backchannel_logout_session_required': True, + 'backchannel_logout_uri': 'https://rp.example.com/back', + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'contacts': ['ops@example.org'], + 'default_max_age': 86400, + 'encrypt_id_token_supported': False, + 'encrypt_request_object_supported': False, + 'encrypt_userinfo_supported': False, + 'grant_types': ['authorization_code', 'implicit', 'refresh_token'], + 'id_token_signed_response_alg': 'RS256', + 'post_logout_redirect_uris': ['https://rp.example.com/post'], + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'request_object_signing_alg': 'ES256', + 'response_modes_supported': ['query', 'fragment', 'form_post'], + 'response_types': ['code'], + 'scope': ['openid'], + 'subject_type': 'public', + 'token_endpoint_auth_method': 'private_key_jwt', + 'token_endpoint_auth_signing_alg': 'ES256', + 'userinfo_signed_response_alg': 'ES256'} def test_response_types_to_grant_types(): @@ -857,9 +859,12 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, + entity = Entity(keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, client_type='oidc') entity.get_context().issuer = "https://example.com" + entity.get_context().map_supported_to_preferred() self.service = entity.get_service("registration") def test_construct(self): @@ -874,11 +879,12 @@ def test_construct(self): 'request_object_signing_alg', 'response_types', 'subject_type', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} def test_config_with_post_logout(self): - self.service.upstream_get("context").work_environment.set_preference( + self.service.upstream_get("context").claims.set_preference( "post_logout_redirect_uri", "https://example.com/post_logout") _req = self.service.construct() @@ -893,6 +899,7 @@ def test_config_with_post_logout(self): 'request_object_signing_alg', 'response_types', 'subject_type', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} assert "post_logout_redirect_uri" in _req.keys() @@ -992,7 +999,7 @@ def create_request(self): entity.get_context().issuer = "https://example.com" self.service = entity.get_service("userinfo") - entity.get_context().work_environment.use = { + entity.get_context().claims.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", diff --git a/tests/test_client_24_oic_utils.py b/tests/test_client_24_oic_utils.py index 1e1b42f9..4e799803 100644 --- a/tests/test_client_24_oic_utils.py +++ b/tests/test_client_24_oic_utils.py @@ -27,9 +27,9 @@ def test_request_object_encryption(): "client_secret": "abcdefghijklmnop", } service_context = ServiceContext(keyjar=KEYJAR, config=conf) - _condition = service_context.work_environment - _condition.set_usage("request_object_encryption_alg", "RSA1_5") - _condition.set_usage("request_object_encryption_enc", "A128CBC-HS256") + _claims = service_context.claims + _claims.set_usage("request_object_encryption_alg", "RSA1_5") + _claims.set_usage("request_object_encryption_enc", "A128CBC-HS256") _jwe = request_object_encryption(msg.to_json(), service_context, KEYJAR, target=RECEIVER) assert _jwe diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index d15ce18a..e480387c 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -47,7 +47,7 @@ "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, "discovery": { - "class": "idpyoidc.client.oidc.provider_info_discovery" ".ProviderInfoDiscovery" + "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery" }, "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, @@ -258,15 +258,16 @@ def test_init_client(self): 'response_types_supported', 'callback_uris', 'scopes_supported'} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) # The key jar should only contain a symmetric key that is the clients # secret. 2 because one is marked for encryption and the other signing # usage. - assert set(_context.keyjar.owners()) == {"", 'eeeeeeeee', _github_id} - keys = _context.keyjar.get_issuer_keys("") - assert len(keys) == 2 + assert set(_keyjar.owners()) == {"", 'eeeeeeeee', _github_id} + keys = _keyjar.get_issuer_keys("") + assert len(keys) == 3 assert _context.base_url == BASE_URL @@ -306,11 +307,12 @@ def test_do_client_setup(self): assert _context.get_preference("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.issuer == _github_id - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert set(_context.keyjar.owners()) == {"", "eeeeeeeee", _github_id} - keys = _context.keyjar.get_issuer_keys("") - assert len(keys) == 2 + assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} + keys = _keyjar.get_issuer_keys("") + assert len(keys) == 3 for service_type in ["authorization", "accesstoken", "userinfo"]: _srv = client.get_service(service_type) @@ -422,7 +424,8 @@ def test_get_tokens(self): _github_id = iss_id("github") _context = client.get_context() - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] _iss = _session['iss'] @@ -493,7 +496,8 @@ def test_access_and_id_token(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -538,7 +542,8 @@ def test_access_and_id_token_by_reference(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -583,7 +588,8 @@ def test_get_user_info(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -674,7 +680,8 @@ def rphandler_setup(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -920,9 +927,8 @@ def test_finalize(self): sub="EndUserSubject", given_name="Diana", family_name="Krall", occupation="Jazz pianist" ) _github_id = iss_id("github") - client.get_context().keyjar.import_jwks( - GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id - ) + _keyjar = client.get_attribute('keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) with responses.RequestsMock() as rsps: rsps.add( "POST", diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index 19916f46..1355bf22 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -1,9 +1,9 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -from cryptojwt.key_jar import build_keyjar import pytest import responses +from cryptojwt.key_jar import build_keyjar from idpyoidc.client.defaults import DEFAULT_KEY_DEFS from idpyoidc.client.rp_handler import RPHandler @@ -14,6 +14,7 @@ class TestRPHandler(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.rph = RPHandler(BASE_URL) @@ -35,7 +36,7 @@ def test_init_client(self): _context = client.get_context() - assert set(_context.work_environment.prefer.keys()) == { + assert set(_context.claims.prefer.keys()) == { 'application_type', 'callback_uris', 'id_token_encryption_alg_values_supported', @@ -96,23 +97,25 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.work_environment.use.keys()) == {'application_type', - 'callback_uris', - 'client_id', - 'client_secret', - 'default_max_age', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'response_modes_supported', - 'response_types', - 'scope', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_context.claims.use.keys()) == {'application_type', + 'callback_uris', + 'client_id', + 'client_secret', + 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks_uri', + 'redirect_uris', + 'request_object_signing_alg', + 'response_modes_supported', + 'response_types', + 'scope', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} assert _context.get_client_id() == "client uno" assert _context.get_usage("client_secret") == "VerySecretAndLongEnough" assert _context.get("issuer") == ISS_ID diff --git a/tests/test_client_40_dpop.py b/tests/test_client_40_dpop.py index 6d6849b4..f96661e5 100644 --- a/tests/test_client_40_dpop.py +++ b/tests/test_client_40_dpop.py @@ -33,7 +33,7 @@ def create_client(self): "add_ons": { "dpop": { "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", - "kwargs": {"signing_algorithms": ["ES256", "ES512"]}, + "kwargs": {"dpop_signing_alg_values_supported": ["ES256", "ES512"]}, } }, } @@ -81,7 +81,7 @@ def create_client(self): "add_ons": { "dpop": { "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", - "kwargs": {"signing_algorithms": ["ES256", "ES512"]}, + "kwargs": {"dpop_signing_alg_values_supported": ["ES256", "ES512"]}, } }, } diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 5e6b91ea..965392d9 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -120,7 +120,7 @@ "kwargs": {"conf": {"default_authn_method": ""}}, }, "refresh_access_token": { - "class": "idpyoidc.client.oidc.refresh_access_token" ".RefreshAccessToken" + "class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken" }, }, }, @@ -241,11 +241,12 @@ def test_do_client_setup(self): assert _context.get_usage("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.get("issuer") == _github_id - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert set(_context.keyjar.owners()) == {"", 'eeeeeeeee', _github_id} - keys = _context.keyjar.get_issuer_keys("") - assert len(keys) == 2 + assert set(_keyjar.owners()) == {"", 'eeeeeeeee', _github_id} + keys = _keyjar.get_issuer_keys("") + assert len(keys) == 3 # one symmetric, one RSA and one EC for service_type in ["authorization", "accesstoken", "userinfo"]: _srv = client.get_service(service_type) @@ -360,7 +361,8 @@ def test_get_tokens(self): _github_id = iss_id("github") _context = client.get_context() - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] _iss = _session["iss"] @@ -441,7 +443,8 @@ def test_access_and_id_token(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -490,7 +493,8 @@ def test_access_and_id_token_by_reference(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -539,7 +543,8 @@ def test_get_user_info(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index 6758e6b1..f7fbf39f 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -36,7 +36,7 @@ def create_request(self): entity.get_context().issuer = "https://server.otherop.com" self.service = entity.get_service("userinfo") - entity.get_context().work_environment.use = { + entity.get_context().claims.use = { "userinfo_signed_response_alg": "RS256", "userinfo_encrypted_response_alg": "RSA-OAEP", "userinfo_encrypted_response_enc": "A256GCM", diff --git a/tests/test_server_01_claims.py b/tests/test_server_01_claims.py index 41d8d0bf..9162e329 100644 --- a/tests/test_server_01_claims.py +++ b/tests/test_server_01_claims.py @@ -128,9 +128,9 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): self.server = Server(conf) - # self.endpoint_context = EndpointContext(conf=conf, upstream_get=self.upstream_get) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + # self.context = EndpointContext(conf=conf, upstream_get=self.upstream_get) + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -143,7 +143,7 @@ def create_idtoken(self): } self.server.get_attribute('keyjar').add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) - self.claims_interface = self.endpoint_context.claims_interface + self.claims_interface = self.context.claims_interface self.user_id = USER_ID @@ -156,7 +156,7 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): client_id = authz_req["client_id"] ae = create_authn_event(self.user_id) - return self.endpoint_context.session_manager.create_session( + return self.context.session_manager.create_session( ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type ) @@ -183,7 +183,7 @@ def test_authorization_request_userinfo_claims_3(self): def test_get_claims_id_token_1(self): session_id = self._create_session(AREQ) - self.endpoint_context.session_manager.token_handler["id_token"].kwargs = { + self.context.session_manager.token_handler["id_token"].kwargs = { "base_claims": {"email": None, "email_verified": None} } claims = self.claims_interface.get_claims(session_id, [], "id_token") @@ -191,11 +191,11 @@ def test_get_claims_id_token_1(self): def test_get_claims_id_token_2(self): session_id = self._create_session(AREQ) - self.endpoint_context.session_manager.token_handler["id_token"].kwargs = { + self.context.session_manager.token_handler["id_token"].kwargs = { "base_claims": {"email": None, "email_verified": None}, "enable_claims_per_client": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ "name", "email", ] @@ -205,12 +205,12 @@ def test_get_claims_id_token_2(self): def test_get_claims_id_token_3(self): session_id = self._create_session(AREQ) - self.endpoint_context.session_manager.token_handler["id_token"].kwargs = { + self.context.session_manager.token_handler["id_token"].kwargs = { "base_claims": {"email": None, "email_verified": None}, "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ "name", "email", ] @@ -226,16 +226,16 @@ def test_get_claims_id_token_3(self): def test_get_claims_id_token_and_userinfo(self): session_id = self._create_session(AREQ) - self.endpoint_context.session_manager.token_handler["id_token"].kwargs = { + self.context.session_manager.token_handler["id_token"].kwargs = { "base_claims": {"email": None, "email_verified": None}, "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = [ "name", "email", ] - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ "phone", "phone_verified", ] @@ -254,13 +254,13 @@ def test_get_claims_id_token_and_userinfo(self): } def test_get_claims_access_token_3(self): - _module = self.endpoint_context.session_manager.token_handler["access_token"] + _module = self.context.session_manager.token_handler["access_token"] _module.kwargs = { "base_claims": {"email": None, "email_verified": None}, "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["access_token"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["access_token"] = [ "name", "email", ] diff --git a/tests/test_server_03_authz_handling.py b/tests/test_server_03_authz_handling.py index af152e81..4edc5c92 100644 --- a/tests/test_server_03_authz_handling.py +++ b/tests/test_server_03_authz_handling.py @@ -126,7 +126,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): server = Server(conf) - server.endpoint_context.cdb["client_1"] = { + server.context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -138,10 +138,10 @@ def create_idtoken(self): "client_1", "hemligtochintekort", ["sig", "enc"] ) server.endpoint = do_endpoints(conf, server.upstream_get) - self.session_manager = server.endpoint_context.session_manager + self.session_manager = server.context.session_manager self.user_id = USER_ID self.server = server - self.authz = server.endpoint_context.authz + self.authz = server.context.authz def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -176,7 +176,7 @@ def test_usage_rules(self): def test_usage_rules_client(self): _ = self._create_session(AREQ) - self.server.endpoint_context.cdb["client_1"]["token_usage_rules"] = { + self.server.context.cdb["client_1"]["token_usage_rules"] = { "authorization_code": {"supports_minting": ["access_token", "id_token"]}, "refresh_token": {}, } diff --git a/tests/test_server_05_token_handler.py b/tests/test_server_05_token_handler.py index 7bf41728..21d247df 100644 --- a/tests/test_server_05_token_handler.py +++ b/tests/test_server_05_token_handler.py @@ -190,7 +190,7 @@ def test_token_handler_from_config(): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - token_handler = server.endpoint_context.session_manager.token_handler + token_handler = server.context.session_manager.token_handler assert token_handler assert len(token_handler.handler) == 4 assert set(token_handler.handler.keys()) == { @@ -280,5 +280,5 @@ def test_file(jwks): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - token_handler = server.endpoint_context.session_manager.token_handler + token_handler = server.context.session_manager.token_handler assert token_handler diff --git a/tests/test_server_06_grant.py b/tests/test_server_06_grant.py index 287c8b42..f56a80e0 100644 --- a/tests/test_server_06_grant.py +++ b/tests/test_server_06_grant.py @@ -110,7 +110,7 @@ class TestGrant: @pytest.fixture(autouse=True) def create_session_manager(self): self.server = Server(conf=conf) - self.endpoint_context = self.server.get_context() + self.context = self.server.get_context() def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -121,27 +121,27 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): client_id = authz_req["client_id"] ae = create_authn_event(USER_ID) - return self.server.endpoint_context.session_manager.create_session( + return self.server.context.session_manager.create_session( ae, authz_req, USER_ID, client_id=client_id, sub_type=sub_type ) def test_mint_token(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -152,20 +152,20 @@ def test_mint_token(self): def test_grant(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -173,7 +173,7 @@ def test_grant(self): refresh_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -184,20 +184,20 @@ def test_grant(self): def test_get_token(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -213,20 +213,20 @@ def test_get_token(self): def test_grant_revoked_based_on(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -234,7 +234,7 @@ def test_grant_revoked_based_on(self): refresh_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -250,20 +250,20 @@ def test_grant_revoked_based_on(self): def test_revoke(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -276,7 +276,7 @@ def test_revoke(self): access_token_2 = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -289,20 +289,20 @@ def test_revoke(self): def test_json_conversion(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -325,7 +325,7 @@ def test_json_conversion(self): def test_json_no_token_map(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -333,14 +333,14 @@ def test_json_no_token_map(self): with pytest.raises(ValueError): grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) def test_json_custom_token_map(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -350,14 +350,14 @@ def test_json_custom_token_map(self): grant.token_map = token_map code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -365,7 +365,7 @@ def test_json_custom_token_map(self): grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="my_token", token_handler=DefaultToken("my_token", typ="M"), ) @@ -393,7 +393,7 @@ def test_json_custom_token_map(self): def test_get_spec(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -404,14 +404,14 @@ def test_get_spec(self): code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -427,7 +427,7 @@ def test_get_spec(self): def test_get_usage_rules(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] @@ -437,22 +437,22 @@ def test_get_usage_rules(self): grant.resources = ["https://api.example.com"] # Default usage rules - self.endpoint_context.cdb["client_id"] = {} - rules = get_usage_rules("access_token", self.endpoint_context, grant, "client_id") + self.context.cdb["client_id"] = {} + rules = get_usage_rules("access_token", self.context, grant, "client_id") assert rules == {"supports_minting": [], "expires_in": 3600} # client specific usage rules - self.endpoint_context.cdb["client_id"] = {"access_token": {"expires_in": 600}} + self.context.cdb["client_id"] = {"access_token": {"expires_in": 600}} def test_assigned_scope(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) @@ -461,7 +461,7 @@ def test_assigned_scope(self): access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -471,13 +471,13 @@ def test_assigned_scope(self): def test_assigned_scope_2nd(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) @@ -486,7 +486,7 @@ def test_assigned_scope_2nd(self): refresh_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -494,7 +494,7 @@ def test_assigned_scope_2nd(self): access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=refresh_token, @@ -506,7 +506,7 @@ def test_assigned_scope_2nd(self): access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=refresh_token, @@ -516,20 +516,20 @@ def test_assigned_scope_2nd(self): def test_grant_remove_based_on_code(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -537,7 +537,7 @@ def test_grant_remove_based_on_code(self): refresh_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, @@ -548,20 +548,20 @@ def test_grant_remove_based_on_code(self): def test_grant_remove_one_by_one(self): session_id = self._create_session(AREQ) - session_info = self.endpoint_context.session_manager.get_session_info( + session_info = self.context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] code = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=TOKEN_HANDLER["authorization_code"], ) access_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=TOKEN_HANDLER["access_token"], based_on=code, @@ -569,7 +569,7 @@ def test_grant_remove_one_by_one(self): refresh_token = grant.mint_token( session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=TOKEN_HANDLER["refresh_token"], based_on=code, diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index 85c99b1a..fddf289f 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -162,8 +162,8 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_session_manager(self): self.server = Server(conf) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -176,7 +176,7 @@ def create_session_manager(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } self.server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager self.user_id = USER_ID def _create_session(self, auth_req, sub_type="public", sector_identifier="", authn_info=""): @@ -196,7 +196,7 @@ def _mint_code(self, grant, session_id): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -205,7 +205,7 @@ def _mint_code(self, grant, session_id): def _mint_access_token(self, grant, session_id, token_ref): access_token = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -216,7 +216,7 @@ def _mint_access_token(self, grant, session_id, token_ref): def _mint_id_token(self, grant, session_id, token_ref=None, code=None, access_token=None): return grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="id_token", token_handler=self.session_manager.token_handler["id_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -413,7 +413,7 @@ def test_sign_encrypt_id_token(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -421,9 +421,9 @@ def test_sign_encrypt_id_token(self): assert res["aud"] == ["client_1"] def test_get_sign_algorithm(self): - client_info = self.endpoint_context.cdb[AREQ["client_id"]] + client_info = self.context.cdb[AREQ["client_id"]] algs = get_sign_and_encrypt_algorithms( - self.endpoint_context, + self.context, client_info, "id_token", sign=True, @@ -432,15 +432,13 @@ def test_get_sign_algorithm(self): assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS256"} algs = get_sign_and_encrypt_algorithms( - self.endpoint_context, client_info, "id_token", sign=True, encrypt=True + self.context, client_info, "id_token", sign=True, encrypt=True ) # default signing alg assert algs == { "sign": True, "encrypt": True, - "sign_alg": "RS256", - "enc_alg": "RSA-OAEP", - "enc_enc": "A128CBC-HS256", + "sign_alg": "RS256" } def test_available_claims(self): @@ -453,7 +451,7 @@ def test_available_claims(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) assert "nickname" in res @@ -466,7 +464,7 @@ def test_lifetime_default(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -483,7 +481,7 @@ def test_lifetime(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -498,7 +496,7 @@ def test_no_available_claims(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) assert "foobar" not in res @@ -508,11 +506,11 @@ def test_client_claims(self): grant = self.session_manager[session_id] self.session_manager.token_handler["id_token"].kwargs["enable_claims_per_client"] = True - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["id_token"] = { + self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = { "address": None } - _claims = self.endpoint_context.claims_interface.get_claims( + _claims = self.context.claims_interface.get_claims( session_id=session_id, scopes=AREQ["scope"], claims_release_point="id_token" ) grant.claims = {"id_token": _claims} @@ -521,7 +519,7 @@ def test_client_claims(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) assert "address" in res @@ -531,7 +529,7 @@ def test_client_claims_with_default(self): session_id = self._create_session(AREQ) grant = self.session_manager[session_id] - _claims = self.endpoint_context.claims_interface.get_claims( + _claims = self.context.claims_interface.get_claims( session_id=session_id, scopes=AREQ["scope"], claims_release_point="id_token" ) grant.claims = {"id_token": _claims} @@ -540,7 +538,7 @@ def test_client_claims_with_default(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) @@ -559,7 +557,7 @@ def test_client_claims_scopes(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) assert "address" in res @@ -570,9 +568,9 @@ def test_client_claims_scopes_per_client(self): session_id = self._create_session(AREQS) grant = self.session_manager[session_id] self.session_manager.token_handler["id_token"].kwargs["add_claims_by_scope"] = True - self.endpoint_context.cdb[AREQS["client_id"]]["add_claims"]["by_scope"]["id_token"] = False + self.context.cdb[AREQS["client_id"]]["add_claims"]["by_scope"]["id_token"] = False - _claims = self.endpoint_context.claims_interface.get_claims( + _claims = self.context.claims_interface.get_claims( session_id=session_id, scopes=AREQS["scope"], claims_release_point="id_token" ) grant.claims = {"id_token": _claims} @@ -581,7 +579,7 @@ def test_client_claims_scopes_per_client(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) assert "address" in res @@ -599,7 +597,7 @@ def test_client_claims_scopes_and_request_claims_no_match(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) # User information, from scopes -> claims @@ -622,7 +620,7 @@ def test_client_claims_scopes_and_request_claims_one_match(self): client_keyjar = KeyJar() _jwks = self.server.keyjar.export_jwks() - client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + client_keyjar.import_jwks(_jwks, self.context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(id_token.value) # Email didn't match @@ -640,8 +638,8 @@ def test_id_token_info(self): grant, session_id, token_ref=code, access_token=access_token.value ) - endpoint_context = self.endpoint_context - sman = endpoint_context.session_manager + context = self.context + sman = context.session_manager _info = self.session_manager.token_handler.info(id_token.value) assert "sid" in _info assert "exp" in _info @@ -654,9 +652,9 @@ def test_id_token_info(self): # TODO: we need an authentication event for this id_token for a better coverage _id_token.payload(session_id) - client_info = endpoint_context.cdb[client_id] + client_info = context.cdb[client_id] get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type="id_token", sign=True, encrypt=True + context, client_info, payload_type="id_token", sign=True, encrypt=True ) def test_id_token_acr_claim(self): diff --git a/tests/test_server_09_authn_context.py b/tests/test_server_09_authn_context.py index 07c5dcf9..0d77920b 100644 --- a/tests/test_server_09_authn_context.py +++ b/tests/test_server_09_authn_context.py @@ -150,8 +150,8 @@ def create_authn_broker(self): # cookie_handler = CookieHandler(**cookie_conf) # server = Server(conf, cookie_handler=cookie_handler) server = Server(conf) - endpoint_context = server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", diff --git a/tests/test_server_12_session_life.py b/tests/test_server_12_session_life.py index 9a86d74b..2bbd3856 100644 --- a/tests/test_server_12_session_life.py +++ b/tests/test_server_12_session_life.py @@ -48,8 +48,8 @@ def setup_token_handler(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = server.endpoint_context - self.session_manager = self.endpoint_context.session_manager + self.context = server.context + self.session_manager = self.context.session_manager def auth(self): # Start with an authentication request @@ -97,7 +97,7 @@ def auth(self): code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -145,7 +145,7 @@ def test_code_flow(self): grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -156,7 +156,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=self.session_manager.token_handler["refresh_token"], based_on=tok, @@ -182,7 +182,7 @@ def test_code_flow(self): access_token_2 = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -279,10 +279,10 @@ def setup_session_manager(self): }, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), keyjar=KEYJAR, cwd=BASEDIR) - self.endpoint_context = server.endpoint_context - self.session_manager = self.endpoint_context.session_manager - # self.session_manager = SessionManager(handler=self.endpoint_context.sdb.handler) - # self.endpoint_context.session_manager = self.session_manager + self.context = server.context + self.session_manager = self.context.session_manager + # self.session_manager = SessionManager(handler=self.context.sdb.handler) + # self.context.session_manager = self.session_manager def auth(self): # Start with an authentication request @@ -330,7 +330,7 @@ def auth(self): # Constructing an authorization code is now done by code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -378,7 +378,7 @@ def test_code_flow(self): grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now @@ -390,7 +390,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="refresh_token", token_handler=self.session_manager.token_handler["refresh_token"], based_on=tok, @@ -418,13 +418,13 @@ def test_code_flow(self): # Can I use this token to mint another token ? assert grant.is_active() - user_claims = self.endpoint_context.userinfo( + user_claims = self.context.userinfo( user_id, client_id=TOKEN_REQ["client_id"], user_info_claims=grant.claims ) access_token_2 = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now diff --git a/tests/test_server_13_user_authn.py b/tests/test_server_13_user_authn.py index c274855f..555d258d 100644 --- a/tests/test_server_13_user_authn.py +++ b/tests/test_server_13_user_authn.py @@ -80,8 +80,8 @@ def create_endpoint_context(self): }, } self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = self.server.endpoint_context - self.session_manager = self.endpoint_context.session_manager + self.context = self.server.context + self.session_manager = self.context.session_manager self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier=""): @@ -97,20 +97,20 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): ) def test_authenticated_as_without_cookie(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] _info, _time_stamp = method.authenticated_as(None) assert _info is None def test_authenticated_as_with_cookie(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] authn_req = {"state": "state_identifier", "client_id": "client 12345"} _sid = self._create_session(authn_req) - _cookie = self.endpoint_context.new_cookie( - name=self.endpoint_context.cookie_handler.name["session"], + _cookie = self.context.new_cookie( + name=self.context.cookie_handler.name["session"], sub="diana", sid=_sid, state=authn_req["state"], @@ -118,8 +118,8 @@ def test_authenticated_as_with_cookie(self): ) # Parsed once before authenticated_as - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=[_cookie], name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=[_cookie], name=self.context.cookie_handler.name["session"] ) _info, _time_stamp = method.authenticated_as("client 12345", kakor) @@ -139,7 +139,7 @@ def test_userpassjinja2(self): "class": JSONDictDB, "kwargs": {"filename": full_path("passwd.json")}, } - template_handler = self.endpoint_context.template_handler + template_handler = self.context.template_handler res = UserPassJinja2(db, template_handler, upstream_get=self.server.unit_get) res() assert "page_header" in res.kwargs diff --git a/tests/test_server_15_login_hint.py b/tests/test_server_15_login_hint.py index 45fe267b..b6ca121b 100644 --- a/tests/test_server_15_login_hint.py +++ b/tests/test_server_15_login_hint.py @@ -46,4 +46,4 @@ def test_server_login_hint_lookup(): configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - assert server.endpoint_context.login_hint_lookup("tel:0907865000") == "diana" + assert server.context.login_hint_lookup("tel:0907865000") == "diana" diff --git a/tests/test_server_16_endpoint.py b/tests/test_server_16_endpoint.py index 25b863f6..9ebf8173 100755 --- a/tests/test_server_16_endpoint.py +++ b/tests/test_server_16_endpoint.py @@ -76,8 +76,8 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - server.endpoint_context.cdb["client_id"] = {} - self.endpoint_context = server.endpoint_context + server.context.cdb["client_id"] = {} + self.context = server.context _endpoints = do_endpoints(conf, server.unit_get) self.endpoint = _endpoints[""] @@ -89,7 +89,7 @@ def test_parse_urlencoded(self): def test_parse_url(self): self.endpoint.request_format = "url" - request = "{}?{}".format(self.endpoint_context.issuer, REQ.to_urlencoded()) + request = "{}?{}".format(self.context.issuer, REQ.to_urlencoded()) req = self.endpoint.parse_request(request, http_info={}) assert req == REQ diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index af5d04b3..751018c3 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -4,7 +4,7 @@ import pytest from cryptojwt.key_jar import build_keyjar -from idpyoidc import work_environment +from idpyoidc import metadata from idpyoidc.server import OPConfiguration from idpyoidc.server import Server from idpyoidc.server.endpoint import Endpoint @@ -27,9 +27,9 @@ class Endpoint_1(Endpoint): name = "userinfo" _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], - "userinfo_signing_alg_values_supported": work_environment.get_signing_algs, - "userinfo_encryption_alg_values_supported": work_environment.get_encryption_algs, - "userinfo_encryption_enc_values_supported": work_environment.get_encryption_encs, + "userinfo_signing_alg_values_supported": metadata.get_signing_algs, + "userinfo_encryption_alg_values_supported": metadata.get_encryption_algs, + "userinfo_encryption_enc_values_supported": metadata.get_encryption_encs, "client_authn_method": ["bearer_header", "bearer_body"], "encrypt_userinfo_supported": False, } @@ -97,40 +97,42 @@ class TestEndpointContext: @pytest.fixture(autouse=True) def create_endpoint_context(self): server = Server(conf) - self.endpoint_context = server.endpoint_context + server.context.map_supported_to_preferred() + self.context = server.context def test(self): - assert set(self.endpoint_context.provider_info.keys()) == { + self.context.set_provider_info() + assert set(self.context.provider_info.keys()) == { 'grant_types_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', 'issuer', 'jwks_uri', 'scopes_supported', - 'userinfo_signing_alg_values_supported'} + 'subject_types_supported', + 'userinfo_signing_alg_values_supported', + 'version'} def test_allow_refresh_token(self): - assert allow_refresh_token(self.endpoint_context) + assert allow_refresh_token(self.context) # Have the software but is not expected to use it. - self.endpoint_context.set_preference("grant_types_supported", [ + self.context.set_preference("grant_types_supported", [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", ]) - assert allow_refresh_token(self.endpoint_context) is False + assert allow_refresh_token(self.context) is False # Don't have the software but are expected to use it. - self.endpoint_context.set_preference("grant_types_supported", [ + self.context.set_preference("grant_types_supported", [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", ]) - del self.endpoint_context.session_manager.token_handler.handler["refresh_token"] + del self.context.session_manager.token_handler.handler["refresh_token"] with pytest.raises(OidcEndpointError): - assert allow_refresh_token(self.endpoint_context) is False + assert allow_refresh_token(self.context) is False class Tokenish(Endpoint): @@ -194,16 +196,18 @@ def test_provider_configuration(kwargs): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - server.endpoint_context.cdb["client_id"] = {} - pi = server.endpoint_context.provider_info - assert set(pi.keys()) == {'grant_types_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', + server.context.cdb["client_id"] = {} + server.context.set_provider_info() + pi = server.context.provider_info + assert set(pi.keys()) == {'acr_values_supported', + 'grant_types_supported', 'id_token_signing_alg_values_supported', 'issuer', 'jwks_uri', 'scopes_supported', - 'token_endpoint_auth_methods_supported'} + 'subject_types_supported', + 'token_endpoint_auth_methods_supported', + 'version'} if kwargs: if 'token_endpoint_auth_methods_supported' in kwargs: diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index 9b72e6ae..135ffee1 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -117,9 +117,9 @@ class Endpoint_4(Endpoint): KEYJAR.add_symmetric(client_id, client_secret, ["sig"]) -def get_client_id_from_token(endpoint_context, token, request=None): +def get_client_id_from_token(context, token, request=None): if "client_id" in request: - if request["client_id"] == endpoint_context.registration_access_token[token]: + if request["client_id"] == context.registration_access_token[token]: return request["client_id"] return "" @@ -128,8 +128,8 @@ class TestClientSecretBasic: @pytest.fixture(autouse=True) def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context server.endpoint = do_endpoints(CONF, server.unit_get) self.method = ClientSecretBasic(server.unit_get) @@ -163,8 +163,8 @@ class TestClientSecretPost: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context self.method = ClientSecretPost(server.unit_get) def test_client_secret_post(self): @@ -186,8 +186,8 @@ class TestClientSecretJWT: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context self.method = ClientSecretJWT(server.unit_get) def test_client_secret_jwt(self): @@ -213,10 +213,10 @@ class TestPrivateKeyJWT: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = PrivateKeyJWT(server.unit_get) def test_private_key_jwt(self): @@ -306,10 +306,10 @@ class TestBearerHeader: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = BearerHeader(server.unit_get) def test_bearerheader(self): @@ -329,10 +329,10 @@ class TestBearerBody: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = BearerBody(server.unit_get) def test_bearer_body(self): @@ -349,10 +349,10 @@ class TestJWSAuthnMethod: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} server.endpoint = do_endpoints(CONF, server.unit_get) self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = JWSAuthnMethod(server.unit_get) def test_jws_authn_method_wrong_key(self): @@ -473,12 +473,12 @@ class TestVerify: @pytest.fixture(autouse=True) def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) - self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + self.server.context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.unit_get) - self.endpoint_context = self.server.get_context() + self.context = self.server.get_context() def test_verify_per_client(self): - self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] + self.server.context.cdb[client_id]["client_authn_method"] = ["public"] request = {"client_id": client_id} res = verify_client( @@ -488,10 +488,10 @@ def test_verify_per_client(self): assert res == {"method": "public", "client_id": client_id} def test_verify_per_client_per_endpoint(self): - self.server.endpoint_context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ + self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ "public" ] - self.server.endpoint_context.cdb[client_id]["token_endpoint_client_authn_method"] = [ + self.server.context.cdb[client_id]["token_endpoint_client_authn_method"] = [ "client_secret_post" ] @@ -549,7 +549,7 @@ def test_verify_client_jws_authn_method(self): def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id res = verify_client( request=request, get_client_id_from_token=get_client_id_from_token, @@ -574,7 +574,7 @@ def test_verify_client_client_secret_basic(self): def test_verify_client_bearer_header(self): # A prerequisite for the get_client_id_from_token function - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id token = "Bearer 1234567890" http_info = {"headers": {"authorization": token}} @@ -593,9 +593,9 @@ class TestVerify2: @pytest.fixture(autouse=True) def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) - self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + self.server.context.cdb[client_id] = {"client_secret": client_secret} self.server.endpoint = do_endpoints(CONF, self.server.unit_get) - self.endpoint_context = self.server.get_context() + self.context = self.server.get_context() def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() @@ -619,7 +619,7 @@ def test_verify_client_jws_authn_method(self): def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id res = verify_client( request=request, get_client_id_from_token=get_client_id_from_token, @@ -653,7 +653,7 @@ def test_verify_client_client_secret_basic(self): def test_verify_client_bearer_header(self): # A prerequisite for the get_client_id_from_token function - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id token = "Bearer 1234567890" http_info = {"headers": {"authorization": token}} @@ -707,7 +707,7 @@ class Mock: conf["client_authn_methods"] = {"custom": MagicMock(return_value=mock)} conf["endpoint"]["registration"]["kwargs"]["client_authn_method"] = ["custom"] server = Server(conf=conf, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} server.endpoint = do_endpoints(CONF, server.unit_get) request = {"redirect_uris": ["https://example.com/cb"]} diff --git a/tests/test_server_20a_server.py b/tests/test_server_20a_server.py index fbcf6008..3f41200f 100755 --- a/tests/test_server_20a_server.py +++ b/tests/test_server_20a_server.py @@ -117,7 +117,7 @@ def test_capabilities_default(): configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - assert set(server.endpoint_context.provider_info["response_types_supported"]) == { + assert set(server.context.provider_info["response_types_supported"]) == { "code", "token", "id_token", @@ -126,8 +126,8 @@ def test_capabilities_default(): "id_token token", "code id_token token", } - assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is True - assert server.endpoint_context.get_preference('jwks_uri') == \ + assert server.context.provider_info["request_uri_parameter_supported"] is True + assert server.context.get_preference('jwks_uri') == \ "https://127.0.0.1:443/static/jwks.json" @@ -135,14 +135,14 @@ def test_capabilities_subset1(): _cnf = deepcopy(CONF) _cnf["response_types_supported"] = ["code"] server = Server(_cnf) - assert server.endpoint_context.provider_info["response_types_supported"] == ["code"] + assert server.context.provider_info["response_types_supported"] == ["code"] def test_capabilities_subset2(): _cnf = deepcopy(CONF) _cnf["response_types_supported"] = ["code", "id_token"] server = Server(_cnf) - assert set(server.endpoint_context.provider_info["response_types_supported"]) == { + assert set(server.context.provider_info["response_types_supported"]) == { "code", "id_token", } @@ -152,16 +152,16 @@ def test_capabilities_bool(): _cnf = deepcopy(CONF) _cnf["request_uri_parameter_supported"] = False server = Server(_cnf) - assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is False + assert server.context.provider_info["request_uri_parameter_supported"] is False def test_cdb(): _cnf = deepcopy(CONF) server = Server(_cnf) _clients = yaml.safe_load(io.StringIO(client_yaml)) - server.endpoint_context.cdb = _clients["oidc_clients"] + server.context.cdb = _clients["oidc_clients"] - assert set(server.endpoint_context.cdb.keys()) == {"client1", "client2", "client3"} + assert set(server.context.cdb.keys()) == {"client1", "client2", "client3"} def test_cdb_afs(): @@ -171,4 +171,4 @@ def test_cdb_afs(): "kwargs": {"fdir": full_path("afs"), "value_conv": "idpyoidc.util.JSON"}, } server = Server(_cnf) - assert isinstance(server.endpoint_context.cdb, AbstractFileSystem) + assert isinstance(server.context.cdb, AbstractFileSystem) diff --git a/tests/test_server_20b_claims.py b/tests/test_server_20b_claims.py index 81d290b4..1d95fece 100644 --- a/tests/test_server_20b_claims.py +++ b/tests/test_server_20b_claims.py @@ -116,7 +116,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): server = Server(conf) - server.endpoint_context.cdb["client_1"] = { + server.context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -130,9 +130,9 @@ def create_idtoken(self): server.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) - self.claims_interface = server.endpoint_context.claims_interface - self.endpoint_context = server.endpoint_context - self.session_manager = self.endpoint_context.session_manager + self.claims_interface = server.context.claims_interface + self.context = server.context + self.session_manager = self.context.session_manager self.user_id = USER_ID self.server = server @@ -163,7 +163,7 @@ def test_get_claims_userinfo_3(self): "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] @@ -184,7 +184,7 @@ def test_get_claims_introspection_3(self): "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["introspection"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["introspection"] = [ "name", "email", ] @@ -229,14 +229,14 @@ def test_get_claims_all_usage_2(self): self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] self.server.get_endpoint("introspection").kwargs = {"add_claims_by_scope": True} - self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} + self.context.session_manager.token_handler["access_token"].kwargs = {} session_id = self._create_session(AREQ) claims = self.claims_interface.get_claims_all_usage(session_id, ["openid", "address"]) @@ -261,14 +261,14 @@ def test_get_user_claims(self): self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] self.server.get_endpoint("introspection").kwargs = {"add_claims_by_scope": True} - self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} + self.context.session_manager.token_handler["access_token"].kwargs = {} session_id = self._create_session(AREQ) claims_restriction = self.claims_interface.get_claims_all_usage( diff --git a/tests/test_server_20c_authz_handling.py b/tests/test_server_20c_authz_handling.py index e2ea920d..797e5450 100644 --- a/tests/test_server_20c_authz_handling.py +++ b/tests/test_server_20c_authz_handling.py @@ -102,7 +102,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): server = Server(conf) - server.endpoint_context.cdb["client_1"] = { + server.context.cdb["client_1"] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -113,10 +113,10 @@ def create_idtoken(self): server.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) - self.session_manager = server.endpoint_context.session_manager + self.session_manager = server.context.session_manager self.user_id = USER_ID self.server = server - self.authz = server.endpoint_context.authz + self.authz = server.context.authz def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index 2bb761d8..9af23ca0 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -80,9 +80,9 @@ KEYJAR.add_symmetric(client_id, client_secret, ["sig"]) -def get_client_id_from_token(endpoint_context, token, request=None): +def get_client_id_from_token(context, token, request=None): if "client_id" in request: - if request["client_id"] == endpoint_context.registration_access_token[token]: + if request["client_id"] == context.registration_access_token[token]: return request["client_id"] return "" @@ -91,8 +91,8 @@ class TestClientSecretBasic: @pytest.fixture(autouse=True) def setup(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context self.method = ClientSecretBasic(server.unit_get) def test_client_secret_basic(self): @@ -125,8 +125,8 @@ class TestClientSecretPost: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context self.method = ClientSecretPost(server.unit_get) def test_client_secret_post(self): @@ -148,8 +148,8 @@ class TestClientSecretJWT: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = server.endpoint_context + server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = server.context self.method = ClientSecretJWT(server.unit_get) def test_client_secret_jwt(self): @@ -175,9 +175,9 @@ class TestPrivateKeyJWT: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = PrivateKeyJWT(server.unit_get) def test_private_key_jwt(self): @@ -263,9 +263,9 @@ class TestBearerHeader: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = BearerHeader(server.unit_get) def test_bearerheader(self): @@ -285,9 +285,9 @@ class TestBearerBody: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = BearerBody(server.unit_get) def test_bearer_body(self): @@ -304,9 +304,9 @@ class TestJWSAuthnMethod: @pytest.fixture(autouse=True) def create_method(self): server = Server(conf=CONF, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} self.server = server - self.endpoint_context = server.endpoint_context + self.context = server.context self.method = JWSAuthnMethod(server.unit_get) def test_jws_authn_method_wrong_key(self): @@ -427,11 +427,11 @@ class TestVerify: @pytest.fixture(autouse=True) def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) - self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.get_context() + self.server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = self.server.get_context() def test_verify_per_client(self): - self.server.endpoint_context.cdb[client_id]["client_authn_method"] = ["public"] + self.server.context.cdb[client_id]["client_authn_method"] = ["public"] request = {"client_id": client_id} res = verify_client( @@ -443,10 +443,10 @@ def test_verify_per_client(self): assert res == {"method": "public", "client_id": client_id} def test_verify_per_client_per_endpoint(self): - self.server.endpoint_context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ + self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ "public" ] - self.server.endpoint_context.cdb[client_id]["token_endpoint_client_authn_method"] = [ + self.server.context.cdb[client_id]["token_endpoint_client_authn_method"] = [ "client_secret_post" ] @@ -512,7 +512,7 @@ def test_verify_client_jws_authn_method(self): def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, keyjar=self.server.get_attribute('keyjar'), @@ -552,7 +552,7 @@ def test_verify_client_client_secret_basic(self): def test_verify_client_bearer_header(self): # A prerequisite for the get_client_id_from_token function - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id token = "Bearer 1234567890" http_info = {"headers": {"authorization": token}} @@ -573,8 +573,8 @@ class TestVerify2: @pytest.fixture(autouse=True) def create_method(self): self.server = Server(conf=CONF, keyjar=KEYJAR) - self.server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} - self.endpoint_context = self.server.get_context() + self.server.context.cdb[client_id] = {"client_secret": client_secret} + self.context = self.server.get_context() def test_verify_client_jws_authn_method(self): client_keyjar = KeyJar() @@ -600,7 +600,7 @@ def test_verify_client_jws_authn_method(self): def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id res = verify_client( self.endpoint_context, keyjar=self.server.get_attribute('keyjar'), @@ -640,7 +640,7 @@ def test_verify_client_client_secret_basic(self): def test_verify_client_bearer_header(self): # A prerequisite for the get_client_id_from_token function - self.endpoint_context.registration_access_token["1234567890"] = client_id + self.context.registration_access_token["1234567890"] = client_id token = "Bearer 1234567890" http_info = {"headers": {"authorization": token}} @@ -702,7 +702,7 @@ class Mock: conf["client_authn_methods"] = {"custom": MagicMock(return_value=mock)} conf["endpoint"]["registration"]["kwargs"]["client_authn_method"] = ["custom"] server = Server(conf=conf, keyjar=KEYJAR) - server.endpoint_context.cdb[client_id] = {"client_secret": client_secret} + server.context.cdb[client_id] = {"client_secret": client_secret} request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index d7fd2687..87ea8209 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -196,8 +196,8 @@ def create_endpoint(self): "session_params": {"encrypter": SESSION_PARAMS}, } self.server = Server(conf, keyjar=KEYJAR) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -209,7 +209,7 @@ def create_endpoint(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager self.user_id = "diana" self.endpoint = self.server.get_endpoint("session") @@ -229,7 +229,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -240,7 +240,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): def test_parse(self): session_id = self._create_session(AUTH_REQ) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + grant = self.context.authz(session_id=session_id, request=AUTH_REQ) # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token( @@ -257,7 +257,7 @@ def test_parse(self): def test_info(self): session_id = self._create_session(AUTH_REQ) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + grant = self.context.authz(session_id=session_id, request=AUTH_REQ) # code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) @@ -269,16 +269,16 @@ def test_info(self): @pytest.mark.parametrize("enable_claims_per_client", [True, False]) def test_enable_claims_per_client(self, enable_claims_per_client): # Set up configuration - self.endpoint_context.cdb["client_1"]["add_claims"]["always"]["access_token"] = { + self.context.cdb["client_1"]["add_claims"]["always"]["access_token"] = { "address": None } - self.endpoint_context.session_manager.token_handler.handler["access_token"].kwargs[ + self.context.session_manager.token_handler.handler["access_token"].kwargs[ "enable_claims_per_client" ] = enable_claims_per_client session_id = self._create_session(AUTH_REQ) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + grant = self.context.authz(session_id=session_id, request=AUTH_REQ) # code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) @@ -400,8 +400,8 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } self.server = Server(conf, keyjar=KEYJAR) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -413,7 +413,7 @@ def create_endpoint(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "webid"] } - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager self.user_id = "diana" self.endpoint = self.server.get_endpoint("session") @@ -433,7 +433,7 @@ def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): # Constructing an authorization code is now done return grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class=token_class, token_handler=self.session_manager.token_handler.handler[token_class], expires_at=utc_time_sans_frac() + 300, # 5 minutes from now @@ -452,7 +452,7 @@ def test_parse(self): session_id = self._create_session(_auth_req) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=_auth_req) + grant = self.context.authz(session_id=session_id, request=_auth_req) # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token( @@ -478,7 +478,7 @@ def test_mint_with_aud(self): session_id = self._create_session(_auth_req) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=_auth_req) + grant = self.context.authz(session_id=session_id, request=_auth_req) # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token( @@ -509,7 +509,7 @@ def test_mint_with_scope(self): session_id = self._create_session(_auth_req) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=_auth_req) + grant = self.context.authz(session_id=session_id, request=_auth_req) # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token( @@ -540,7 +540,7 @@ def test_mint_with_extra(self): session_id = self._create_session(_auth_req) # apply consent - grant = self.endpoint_context.authz(session_id=session_id, request=_auth_req) + grant = self.context.authz(session_id=session_id, request=_auth_req) # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token( diff --git a/tests/test_server_22_oidc_provider_config_endpoint.py b/tests/test_server_22_oidc_provider_config_endpoint.py index b532d4aa..bd5f20a4 100755 --- a/tests/test_server_22_oidc_provider_config_endpoint.py +++ b/tests/test_server_22_oidc_provider_config_endpoint.py @@ -80,7 +80,7 @@ def conf(self): def create_endpoint(self, conf): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = server.endpoint_context + self.context = server.context self.endpoint = server.get_endpoint("provider_config") def test_do_response(self): diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 35e9d1bf..04a74858 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -163,7 +163,7 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - server.endpoint_context.cdb["client_id"] = {} + server.context.cdb["client_id"] = {} self.endpoint = server.get_endpoint("registration") def test_parse(self): diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 39e7af37..29685a01 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -259,13 +259,13 @@ def create_endpoint(self): } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["clients"] + context.cdb = _clients["clients"] server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) - self.endpoint_context = endpoint_context + self.context = context self.endpoint = server.get_endpoint("authorization") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() @@ -520,8 +520,8 @@ def test_setup_auth(self): ) # Parsed once before setup_auth - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=[kaka], name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=[kaka], name=self.context.cookie_handler.name["session"] ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kakor) diff --git a/tests/test_server_24_oauth2_authorization_endpoint_jar.py b/tests/test_server_24_oauth2_authorization_endpoint_jar.py index 922a65ed..f788c7e4 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_server_24_oauth2_authorization_endpoint_jar.py @@ -187,12 +187,12 @@ def create_endpoint(self): }, } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["clients"] + context.cdb = _clients["clients"] server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.endpoint = server.get_endpoint("authorization") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index a03e7034..78edd336 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -177,8 +177,8 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -187,10 +187,10 @@ def create_endpoint(self, conf): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" - self.endpoint_context = endpoint_context + self.context = context def test_init(self): assert self.token_endpoint @@ -215,7 +215,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -235,7 +235,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None): _token = grant.mint_token( _session_info, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token @@ -262,13 +262,13 @@ def test_parse(self): def test_auth_code_grant_disallowed_per_client(self): areq = AUTH_REQ.copy() areq["scope"] = ["email"] - self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [] + self.context.cdb["client_1"]["grant_types_supported"] = [] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -287,7 +287,7 @@ def test_process_request(self): code = self._mint_code(grant, AUTH_REQ["client_id"]) _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint_context + _context = self.context _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) _resp = self.token_endpoint.process_request(request=_req) @@ -301,7 +301,7 @@ def test_process_request_using_code_twice(self): code = self._mint_code(grant, AUTH_REQ["client_id"]) _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint_context + _context = self.context _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) @@ -332,7 +332,7 @@ def test_process_request_using_private_key_jwt(self): _token_request = TOKEN_REQ_DICT.copy() del _token_request["client_id"] del _token_request["client_secret"] - _context = self.endpoint_context + _context = self.context _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") _jwt.with_jti = True @@ -349,13 +349,13 @@ def test_process_request_using_private_key_jwt(self): def test_do_refresh_access_token(self): areq = AUTH_REQ.copy() - areq["scope"] = ["email"] + areq["scope"] = ["email", "foobar"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -388,13 +388,13 @@ def test_do_refresh_access_token(self): def test_refresh_grant_disallowed_per_client(self): areq = AUTH_REQ.copy() areq["scope"] = ["email"] - self.endpoint_context.cdb["client_1"]["grant_types_supported"] = ["authorization_code"] + self.context.cdb["client_1"]["grant_types_supported"] = ["authorization_code"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -408,11 +408,11 @@ def test_do_2nd_refresh_access_token(self): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) self.token_endpoint.revoke_refresh_on_issue = False - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -452,7 +452,7 @@ def test_do_2nd_refresh_access_token(self): assert isinstance(msg, dict) def test_new_refresh_token(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -465,7 +465,7 @@ def test_new_refresh_token(self, conf): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -492,7 +492,7 @@ def test_new_refresh_token(self, conf): assert first_refresh_token != second_refresh_token def test_revoke_on_issue_refresh_token(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -506,7 +506,7 @@ def test_revoke_on_issue_refresh_token(self, conf): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -530,7 +530,7 @@ def test_revoke_on_issue_refresh_token(self, conf): assert second_refresh_token.revoked is False def test_revoke_on_issue_refresh_token_per_client(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -538,12 +538,12 @@ def test_revoke_on_issue_refresh_token_per_client(self, conf): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.endpoint_context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True + self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True areq = AUTH_REQ.copy() areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -574,7 +574,7 @@ def test_refresh_scopes(self): areq["scope"] = ["email", "profile"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -613,7 +613,7 @@ def test_refresh_more_scopes(self): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -639,7 +639,7 @@ def test_refresh_more_scopes_2(self): areq["scope"] = ["email", "profile"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -689,10 +689,10 @@ def test_do_refresh_access_token_not_allowed(self): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.upstream_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -714,10 +714,10 @@ def test_do_refresh_access_token_revoked(self): areq["scope"] = ["email"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.upstream_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -743,7 +743,7 @@ def test_configure_grant_types(self): assert "refresh_token" not in self.token_endpoint.helper def test_token_request_other_client(self): - _context = self.endpoint_context + _context = self.context _context.cdb["client_2"] = _context.cdb["client_1"] session_id = self._create_session(AUTH_REQ) grant = self.session_manager[session_id] @@ -760,7 +760,7 @@ def test_token_request_other_client(self): assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} def test_refresh_token_request_other_client(self): - _context = self.endpoint_context + _context = self.context _context.cdb["client_2"] = _context.cdb["client_1"] session_id = self._create_session(AUTH_REQ) grant = self.session_manager[session_id] diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index e98326ad..fc0bcca8 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -290,16 +290,16 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["oidc_clients"] + context.cdb = _clients["oidc_clients"] server.keyjar.import_jwks( server.keyjar.export_jwks(True, ""), conf["issuer"] ) - self.endpoint_context = endpoint_context + self.context = context self.endpoint = server.get_endpoint("authorization") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() @@ -664,14 +664,14 @@ def test_setup_auth(self): } session_id = self._create_session(request) - kaka = self.endpoint.upstream_get("endpoint_context").cookie_handler.make_cookie_content( + kaka = self.endpoint.upstream_get("context").cookie_handler.make_cookie_content( value=json.dumps({"sid": session_id, "state": request.get("state")}), - name=self.endpoint_context.cookie_handler.name["session"], + name=self.context.cookie_handler.name["session"], ) # Parsed once before setup_auth - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=[kaka], name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=[kaka], name=self.context.cookie_handler.name["session"] ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kakor) @@ -693,7 +693,7 @@ def test_setup_auth_error(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -716,7 +716,7 @@ def test_setup_auth_user_form_post(self): nonce="nonce", scope="openid", ) - _ec = self.endpoint.upstream_get("endpoint_context") + _ec = self.endpoint.upstream_get("context") session_id = self._create_session(request) @@ -744,7 +744,7 @@ def test_setup_auth_error_form_post(self): scope=["openid"], ) - item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.process_request(request) @@ -768,7 +768,7 @@ def test_setup_auth_session_revoked(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.upstream_get("endpoint_context") + _ec = self.endpoint.upstream_get("context") session_id = self._create_session(request) @@ -782,7 +782,7 @@ def test_setup_auth_session_revoked(self): assert set(res.keys()) == {"args", "function"} def test_check_session_iframe(self): - self.endpoint.upstream_get("endpoint_context").provider_info[ + self.endpoint.upstream_get("context").provider_info[ "check_session_iframe" ] = "https://example.com/csi" _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) @@ -806,7 +806,7 @@ def test_setup_auth_login_hint(self): "id_token_signed_response_alg": "RS256", } - item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -830,13 +830,13 @@ def test_setup_auth_login_hint2acrs(self): "kwargs": {"user": "knoll"}, "class": NoAuthn, } - self.endpoint.upstream_get("endpoint_context").authn_broker["foo"] = init_method( + self.endpoint.upstream_get("context").authn_broker["foo"] = init_method( method_spec, None ) - item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["anon"] + item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication - item = self.endpoint.upstream_get("endpoint_context").authn_broker.db["foo"] + item = self.endpoint.upstream_get("context").authn_broker.db["foo"] item["method"].fail = NoSuchAuthentication res = self.endpoint.pick_authn_method(request, redirect_uri) @@ -852,7 +852,7 @@ def test_parse_request(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -878,7 +878,7 @@ def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( AUTH_REQ_DICT, - aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" @@ -958,11 +958,11 @@ def test_do_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( orig_request.to_dict(), - aud=self.endpoint.upstream_get("endpoint_context").provider_info["issuer"], + aud=self.endpoint.upstream_get("context").provider_info["issuer"], ) - endpoint_context = self.endpoint.upstream_get("endpoint_context") - endpoint_context.cdb["client_1"]["request_uris"] = [("https://example.com/request", {})] + context = self.endpoint.upstream_get("context") + context.cdb["client_1"]["request_uris"] = [("https://example.com/request", {})] with responses.RequestsMock() as rsps: rsps.add( @@ -973,7 +973,7 @@ def test_do_request_uri(self): status=200, ) - self.endpoint._do_request_uri(request, "client_1", endpoint_context) + self.endpoint._do_request_uri(request, "client_1", context) request["request_uri"] = "https://example.com/request#1" @@ -986,19 +986,19 @@ def test_do_request_uri(self): status=200, ) - self.endpoint._do_request_uri(request, "client_1", endpoint_context) + self.endpoint._do_request_uri(request, "client_1", context) request["request_uri"] = "https://example.com/another" with pytest.raises(ValueError): - self.endpoint._do_request_uri(request, "client_1", endpoint_context) + self.endpoint._do_request_uri(request, "client_1", context) - endpoint_context.provider_info["request_uri_parameter_supported"] = False + context.provider_info["request_uri_parameter_supported"] = False with pytest.raises(ServiceError): - self.endpoint._do_request_uri(request, "client_1", endpoint_context) + self.endpoint._do_request_uri(request, "client_1", context) def test_post_parse_request(self): - endpoint_context = self.endpoint.upstream_get("endpoint_context") - msg = self.endpoint._post_parse_request({}, "client_1", endpoint_context) + context = self.endpoint.upstream_get("context") + msg = self.endpoint._post_parse_request({}, "client_1", context) assert "error" in msg request = AuthorizationRequest( @@ -1009,17 +1009,17 @@ def test_post_parse_request(self): scope="openid", ) - msg = self.endpoint._post_parse_request(request, "client_X", endpoint_context) + msg = self.endpoint._post_parse_request(request, "client_X", context) assert "error" in msg assert msg["error_description"] == "unknown client" request["client_id"] = "client_1" - endpoint_context.cdb["client_1"]["redirect_uris"] = [ + context.cdb["client_1"]["redirect_uris"] = [ ("https://example.com/cb", ""), ("https://example.com/2nd_cb", ""), ] - msg = self.endpoint._post_parse_request(request, "client_1", endpoint_context) + msg = self.endpoint._post_parse_request(request, "client_1", context) assert "error" in msg assert msg["error"] == "invalid_request" @@ -1081,7 +1081,7 @@ def test_do_request_user(self): request["login_hint"] = "mail:diana@example.org" assert self.endpoint.do_request_user(request) == {} - endpoint_context = self.endpoint.upstream_get("endpoint_context") + context = self.endpoint.upstream_get("context") # userinfo _userinfo = init_user_info( { @@ -1091,10 +1091,10 @@ def test_do_request_user(self): "", ) # login_hint - endpoint_context.login_hint_lookup = init_service( + context.login_hint_lookup = init_service( {"class": "idpyoidc.server.login_hint.LoginHintLookup"}, None ) - endpoint_context.login_hint_lookup.userinfo = _userinfo + context.login_hint_lookup.userinfo = _userinfo # With login_hint and login_hint_lookup assert self.endpoint.do_request_user(request) == {"req_user": "diana"} @@ -1240,15 +1240,15 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["oidc_clients"] + context.cdb = _clients["oidc_clients"] server.keyjar.import_jwks( server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = server.get_endpoint("authorization") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() @@ -1267,7 +1267,7 @@ def test_setup_acr_claim(self): ) redirect_uri = request["redirect_uri"] - _context = self.endpoint.upstream_get("endpoint_context") + _context = self.endpoint.upstream_get("context") cinfo = _context.cdb["client_1"] res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -1386,8 +1386,8 @@ def create_endpoint_context(self): "template_dir": "template", } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = server.endpoint_context - self.session_manager = self.endpoint_context.session_manager + self.context = server.context + self.session_manager = self.context.session_manager self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier=""): @@ -1403,21 +1403,21 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): ) def test_authenticated_as_without_cookie(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] _info, _time_stamp = method.authenticated_as(None) assert _info is None def test_authenticated_as_with_cookie(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] authn_req = {"state": "state_identifier", "client_id": "client 12345"} session_id = self._create_session(authn_req) - _cookie = self.endpoint_context.new_cookie( - name=self.endpoint_context.cookie_handler.name["session"], + _cookie = self.context.new_cookie( + name=self.context.cookie_handler.name["session"], sub="diana", sid=session_id, state=authn_req["state"], @@ -1425,23 +1425,23 @@ def test_authenticated_as_with_cookie(self): ) # Parsed once before setup_auth - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=[_cookie], name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=[_cookie], name=self.context.cookie_handler.name["session"] ) _info, _time_stamp = method.authenticated_as(client_id="client 12345", cookie=kakor) assert _info["sub"] == "diana" def test_authenticated_as_with_unknown_user(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] authn_req = {"state": "state_identifier", "client_id": "client 12345"} session_id = self._create_session(authn_req) - _cookie = self.endpoint_context.new_cookie( - name=self.endpoint_context.cookie_handler.name["session"], + _cookie = self.context.new_cookie( + name=self.context.cookie_handler.name["session"], sub="adam", - sid=self.endpoint_context.session_manager.encrypted_session_id( + sid=self.context.session_manager.encrypted_session_id( "adam", "client 12345", "0123456789" ), state=authn_req["state"], @@ -1449,23 +1449,23 @@ def test_authenticated_as_with_unknown_user(self): ) # Parsed once before setup_auth - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=[_cookie], name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=[_cookie], name=self.context.cookie_handler.name["session"] ) _info, _time_stamp = method.authenticated_as(client_id="client 12345", cookie=kakor) assert _info == {} def test_authenticated_as_with_goobledigook(self): - authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + authn_item = self.context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) method = authn_item[0]["method"] authn_req = {"state": "state_identifier", "client_id": "client 12345"} _ = self._create_session(authn_req) - _cookie = self.endpoint_context.new_cookie( - name=self.endpoint_context.cookie_handler.name["session"], + _cookie = self.context.new_cookie( + name=self.context.cookie_handler.name["session"], sub="adam", - sid=self.endpoint_context.session_manager.encrypted_session_id( + sid=self.context.session_manager.encrypted_session_id( "adam", "client 12345", "0123456789" ), state=authn_req["state"], diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index 20ca42c8..b8fc9f7a 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -200,8 +200,8 @@ def create_endpoint(self): cookie_handler=self.cd, keyjar=KEYJAR, ) - endpoint_context = server.endpoint_context - endpoint_context.cdb = { + context = server.context + context.cdb = { "client_1": { "client_secret": "hemligt", "redirect_uris": [("{}cb".format(CLI1), None)], @@ -221,8 +221,8 @@ def create_endpoint(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] }, } - self.endpoint_context = endpoint_context - self.session_manager = endpoint_context.session_manager + self.context = context + self.session_manager = context.session_manager self.authn_endpoint = server.get_endpoint("authorization") self.session_endpoint = server.get_endpoint("session") self.token_endpoint = server.get_endpoint("token") @@ -287,7 +287,7 @@ def _auth_with_id_token(self, state): def _mint_token(self, token_class, grant, session_id, token_ref=None): return grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class=token_class, token_handler=self.session_manager.token_handler[token_class], expires_at=utc_time_sans_frac() + 900, # 15 minutes from now diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index 47664844..e599c71f 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -191,8 +191,8 @@ def create_endpoint(self, jwt_token): "kwargs": {}, } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -207,11 +207,11 @@ def create_endpoint(self, jwt_token): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } server.keyjar.import_jwks_as_json( - server.keyjar.export_jwks_as_json(private=True),endpoint_context.issuer + server.keyjar.export_jwks_as_json(private=True),context.issuer ) self.introspection_endpoint = server.get_endpoint("introspection") self.token_endpoint = server.get_endpoint("token") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier=""): @@ -398,7 +398,7 @@ def test_code(self): def test_introspection_claims(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.upstream_get("context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) @@ -406,14 +406,14 @@ def test_introspection_claims(self): self.introspection_endpoint.kwargs["enable_claims_per_client"] = True - _c_interface = self.introspection_endpoint.upstream_get("endpoint_context").claims_interface + _c_interface = self.introspection_endpoint.upstream_get("context").claims_interface grant.claims = { "introspection": _c_interface.get_claims( session_id, scopes=AUTH_REQ["scope"], claims_release_point="introspection" ) } - _context = self.introspection_endpoint.upstream_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { "token": access_token.value, @@ -434,7 +434,7 @@ def test_jwt_unknown_key(self): _jwt = JWT( _keyjar, - iss=self.introspection_endpoint.upstream_get("endpoint_context").issuer, + iss=self.introspection_endpoint.upstream_get("context").issuer, lifetime=3600, ) @@ -442,7 +442,7 @@ def test_jwt_unknown_key(self): _payload = {"sub": "subject_id"} _token = _jwt.pack(_payload, aud="client_1") - _context = self.introspection_endpoint.upstream_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { @@ -465,7 +465,7 @@ def mock(): monkeypatch.setattr("idpyoidc.server.token.utc_time_sans_frac", mock) - _context = self.introspection_endpoint.upstream_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { @@ -481,7 +481,7 @@ def test_revoked_access_token(self): access_token = self._get_access_token(AUTH_REQ) access_token.revoked = True - _context = self.introspection_endpoint.upstream_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { @@ -495,12 +495,12 @@ def test_revoked_access_token(self): def test_introspect_id_token(self): session_id = self._create_session(AUTH_REQ) - grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + grant = self.token_endpoint.upstream_get("context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) id_token = self._mint_token("id_token", grant, session_id, code) - _context = self.introspection_endpoint.upstream_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("context") _req = self.introspection_endpoint.parse_request( { "token": id_token.value, diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index af0c7324..01783749 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -127,7 +127,7 @@ def create_endpoint(self): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.registration_endpoint = server.get_endpoint("registration") self.registration_api_endpoint = server.get_endpoint("registration_read") - server.endpoint_context.cdb["client_1"] = {} + server.context.cdb["client_1"] = {} def test_do_response(self): _req = self.registration_endpoint.parse_request(CLI_REQ.to_json()) diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index 8cd3bd15..137668c5 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -228,9 +228,9 @@ def _code_challenge(): def create_server(config): server = Server(ASConfiguration(conf=config, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["oidc_clients"] + context.cdb = _clients["oidc_clients"] server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), config["issuer"]) return server @@ -239,7 +239,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = create_server(conf) - self.session_manager = server.endpoint_context.session_manager + self.session_manager = server.context.session_manager self.authn_endpoint = server.get_endpoint("authorization") self.token_endpoint = server.get_endpoint("token") @@ -325,8 +325,8 @@ def test_essential_per_client(self, conf): authn_endpoint = server.get_endpoint("authorization") token_endpoint = server.get_endpoint("token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.get_context() - endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = True + context = server.get_context() + context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = True _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) @@ -340,8 +340,8 @@ def test_not_essential_per_client(self, conf): authn_endpoint = server.get_endpoint("authorization") token_endpoint = server.get_endpoint("token") _authn_req = AUTH_REQ.copy() - endpoint_context = server.get_context() - endpoint_context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = False + context = server.get_context() + context.cdb[AUTH_REQ["client_id"]]["pkce_essential"] = False _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) resp = authn_endpoint.process_request(_pr_resp) diff --git a/tests/test_server_34_oidc_sso.py b/tests/test_server_34_oidc_sso.py index b66ddf76..6b4132f2 100755 --- a/tests/test_server_34_oidc_sso.py +++ b/tests/test_server_34_oidc_sso.py @@ -196,14 +196,14 @@ def create_endpoint_context(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = _clients["oidc_clients"] + context.cdb = _clients["oidc_clients"] server.keyjar.import_jwks( server.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = server.get_endpoint("authorization") - self.endpoint_context = endpoint_context + self.context = context self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") server.keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -211,7 +211,7 @@ def create_endpoint_context(self): def test_sso(self): request = self.endpoint.parse_request(AUTH_REQ_DICT) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) # info = self.endpoint.process_request(request) @@ -224,7 +224,7 @@ def test_sso(self): # second login - from 2nd client request = self.endpoint.parse_request(AUTH_REQ_2.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) sid2 = info["session_id"] @@ -237,7 +237,7 @@ def test_sso(self): # third login - from 3rd client request = self.endpoint.parse_request(AUTH_REQ_3.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=None) assert set(info.keys()) == {"session_id", "identity", "user"} @@ -250,11 +250,11 @@ def test_sso(self): request = self.endpoint.parse_request(AUTH_REQ_4.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("context").cdb[request["client_id"]] # Parse cookies once before setup_auth - kakor = self.endpoint_context.cookie_handler.parse_cookie( - cookies=cookies_1, name=self.endpoint_context.cookie_handler.name["session"] + kakor = self.context.cookie_handler.parse_cookie( + cookies=cookies_1, name=self.context.cookie_handler.name["session"] ) info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=kakor) @@ -267,12 +267,12 @@ def test_sso(self): # Fifth login - from 2nd client - wrong cookie request = self.endpoint.parse_request(AUTH_REQ_2.to_dict()) redirect_uri = request["redirect_uri"] - cinfo = self.endpoint.upstream_get("endpoint_context").cdb[request["client_id"]] + cinfo = self.endpoint.upstream_get("context").cdb[request["client_id"]] info = self.endpoint.setup_auth(request, redirect_uri, cinfo, cookie=kakor) # No valid login cookie so new session assert info["session_id"] != sid2 - user_session_info = self.endpoint.upstream_get("endpoint_context").session_manager.get( + user_session_info = self.endpoint.upstream_get("context").session_manager.get( ["diana"] ) assert len(user_session_info.subordinate) == 3 @@ -285,13 +285,13 @@ def test_sso(self): # Should be one grant for each of client_2 and client_3 and # 2 grants for client_1 - csi1 = self.endpoint.upstream_get("endpoint_context").session_manager.get( + csi1 = self.endpoint.upstream_get("context").session_manager.get( ["diana", "client_1"] ) - csi2 = self.endpoint.upstream_get("endpoint_context").session_manager.get( + csi2 = self.endpoint.upstream_get("context").session_manager.get( ["diana", "client_2"] ) - csi3 = self.endpoint.upstream_get("endpoint_context").session_manager.get( + csi3 = self.endpoint.upstream_get("context").session_manager.get( ["diana", "client_3"] ) diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index 34f3ca24..9ae4f83e 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -24,12 +24,10 @@ from idpyoidc.server.oidc.provider_config import ProviderConfiguration from idpyoidc.server.oidc.registration import Registration from idpyoidc.server.oidc.token import Token -from idpyoidc.server.session import MintingNotAllowed from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.user_info import UserInfo from idpyoidc.server.util import lv_pack from idpyoidc.time_util import utc_time_sans_frac - from . import CRYPT_CONFIG from . import SESSION_PARAMS from .test_server_24_oauth2_token_endpoint import TestEndpoint as _TestEndpoint @@ -200,12 +198,13 @@ def conf(): class TestEndpoint(_TestEndpoint): + @pytest.fixture(autouse=True) def create_endpoint(self, conf): self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = self.server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = self.server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -214,11 +213,11 @@ def create_endpoint(self, conf): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } self.server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") - endpoint_context.userinfo = USERINFO - self.session_manager = endpoint_context.session_manager + context.userinfo = USERINFO + self.session_manager = context.session_manager self.token_endpoint = self.server.get_endpoint("token") self.user_id = "diana" - self.endpoint_context = endpoint_context + self.context = context def test_init(self): assert self.token_endpoint @@ -243,7 +242,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -263,7 +262,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None): _token = grant.mint_token( _session_info, - context=self.endpoint_context, + context=self.context, token_class="access_token", token_handler=self.session_manager.token_handler["access_token"], based_on=token_ref, # Means the token (tok) was used to mint this token @@ -293,7 +292,7 @@ def test_process_request(self): code = self._mint_code(grant, AUTH_REQ["client_id"]) _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint_context + _context = self.context _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) _resp = self.token_endpoint.process_request(request=_req) @@ -308,7 +307,7 @@ def test_process_request_using_code_twice(self): code = self._mint_code(grant, AUTH_REQ["client_id"]) _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint_context + _context = self.context _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) @@ -339,7 +338,7 @@ def test_process_request_using_private_key_jwt(self): _token_request = TOKEN_REQ_DICT.copy() del _token_request["client_id"] del _token_request["client_secret"] - _context = self.endpoint_context + _context = self.context _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") _jwt.with_jti = True @@ -359,10 +358,10 @@ def test_do_refresh_access_token(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -406,10 +405,10 @@ def test_do_2nd_refresh_access_token(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) self.token_endpoint.revoke_refresh_on_issue = False - _cntx = self.endpoint_context + _cntx = self.context _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -472,7 +471,7 @@ def test_refresh_scopes(self): areq["scope"] = ["openid", "offline_access", "profile"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -528,7 +527,7 @@ def test_refresh_more_scopes(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -565,7 +564,7 @@ def test_refresh_more_scopes_2(self): areq["scope"] = ["openid", "offline_access", "profile"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -640,7 +639,7 @@ def test_refresh_less_scopes(self): self.session_manager.token_handler.handler["id_token"].kwargs["add_claims_by_scope"] = True session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -686,7 +685,7 @@ def test_refresh_no_openid_scope(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -727,7 +726,7 @@ def test_refresh_no_offline_access_scope(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -769,7 +768,7 @@ def test_refresh_no_offline_access_scope(self): assert _resp["response_args"]["scope"] == ["openid"] def test_new_refresh_token(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -782,7 +781,7 @@ def test_new_refresh_token(self, conf): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -809,7 +808,7 @@ def test_new_refresh_token(self, conf): assert first_refresh_token != second_refresh_token def test_revoke_on_issue_refresh_token(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -822,7 +821,7 @@ def test_revoke_on_issue_refresh_token(self, conf): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -849,7 +848,7 @@ def test_revoke_on_issue_refresh_token(self, conf): assert second_refresh_token.revoked is False def test_revoke_on_issue_refresh_token_per_client(self, conf): - self.endpoint_context.cdb["client_1"] = { + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -857,12 +856,12 @@ def test_revoke_on_issue_refresh_token_per_client(self, conf): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } - self.endpoint_context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True + self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True areq = AUTH_REQ.copy() areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -893,10 +892,10 @@ def test_do_refresh_access_token_not_allowed(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.upstream_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -917,10 +916,10 @@ def test_do_refresh_access_token_revoked(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) - _cntx = self.token_endpoint.upstream_get("endpoint_context") + _cntx = self.token_endpoint.upstream_get("context") _token_request = TOKEN_REQ_DICT.copy() _token_request["code"] = code.value @@ -966,7 +965,7 @@ def test_access_token_lifetime(self): assert access_token["exp"] - access_token["iat"] == lifetime def test_token_request_other_client(self): - _context = self.endpoint_context + _context = self.context _context.cdb["client_2"] = _context.cdb["client_1"] session_id = self._create_session(AUTH_REQ) grant = self.session_manager[session_id] @@ -983,7 +982,7 @@ def test_token_request_other_client(self): assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} def test_refresh_token_request_other_client(self): - _context = self.endpoint_context + _context = self.context _context.cdb["client_2"] = _context.cdb["client_1"] session_id = self._create_session(AUTH_REQ) grant = self.session_manager[session_id] @@ -1015,12 +1014,13 @@ def test_refresh_token_request_other_client(self): class TestOldTokens(object): + @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -1029,10 +1029,10 @@ def create_endpoint(self, conf): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") - self.session_manager = endpoint_context.session_manager + self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") self.user_id = "diana" - self.endpoint_context = endpoint_context + self.context = context def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -1054,7 +1054,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -1118,7 +1118,7 @@ def test_old_jwt_token(self): payload = _handler.load_custom_claims(payload) # payload.update(kwargs) - _context = _handler.upstream_get("endpoint_context") + _context = _handler.upstream_get("context") signer = JWT( key_jar=_handler.upstream_get('attribute', 'keyjar'), iss=_handler.issuer, diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index d6729313..87d84b92 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -179,8 +179,8 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -195,7 +195,7 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } - self.endpoint_context.cdb["client_2"] = { + self.context.cdb["client_2"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -206,7 +206,7 @@ def create_endpoint(self): server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.endpoint = server.get_endpoint("token") self.introspection_endpoint = server.get_endpoint("introspection") - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier=""): @@ -259,7 +259,7 @@ def test_token_exchange1(self, token): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -303,7 +303,7 @@ def test_token_exchange2(self, token): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -343,7 +343,7 @@ def test_token_exchange_per_client(self, token): """ Test that per-client token exchange configuration works correctly """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -366,7 +366,7 @@ def test_token_exchange_per_client(self, token): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -575,7 +575,7 @@ def test_additional_parameters(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -624,7 +624,7 @@ def test_token_exchange_fails_if_disabled(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -662,7 +662,7 @@ def test_wrong_resource(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -696,7 +696,7 @@ def test_refresh_token_audience(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -732,7 +732,7 @@ def test_wrong_audience(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -766,7 +766,7 @@ def test_exchange_refresh_token_to_refresh_token(self): areq["scope"] = ["openid", "offline_access"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() _token_request["scope"] = "openid" @@ -802,7 +802,7 @@ def test_exchange_access_token_to_refresh_token(self, scopes): areq["scope"] = scopes session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -842,7 +842,7 @@ def test_missing_parameters(self, missing_attribute): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -888,7 +888,7 @@ def test_unsupported_requested_token_type(self, unsupported_type): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -932,7 +932,7 @@ def test_unsupported_subject_token_type(self, unsupported_type): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -966,7 +966,7 @@ def test_unsupported_actor_token(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -999,7 +999,7 @@ def test_invalid_token(self): areq = AUTH_REQ.copy() session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 6a0aaffe..0664ed54 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -164,9 +164,9 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) - endpoint_context.cdb = verify_oidc_client_information(_clients["oidc_clients"]) + context.cdb = verify_oidc_client_information(_clients["oidc_clients"]) server.keyjar.import_jwks( server.keyjar.export_jwks(True, ""), conf["issuer"] ) diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index 725510db..a0202cfa 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -218,7 +218,7 @@ def create_endpoint(self): ) # The top most part (Server class instance) is not - server1.endpoint_context.cdb["client_1"] = { + server1.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -228,12 +228,12 @@ def create_endpoint(self): } # make server2 endpoint context a copy of server 1 endpoint context - _store = server1.endpoint_context.dump() - server2.endpoint_context.load( + _store = server1.context.dump() + server2.context.load( _store, init_args={ "upstream_get": server2.upstream_get, - "handler": server2.endpoint_context.session_manager.token_handler, + "handler": server2.context.session_manager.token_handler, }, ) @@ -243,8 +243,8 @@ def create_endpoint(self): } self.session_manager = { - 1: server1.endpoint_context.session_manager, - 2: server2.endpoint_context.session_manager, + 1: server1.context.session_manager, + 2: server2.context.session_manager, } self.user_id = "diana" diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index 156162bc..c93eb150 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -186,8 +186,8 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(OPConfiguration(conf, base_path=BASEDIR), keyjar=KEYJAR) - self.endpoint_context = server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -197,7 +197,7 @@ def create_endpoint(self): } self.user_id = "diana" self.token_endpoint = server.get_endpoint("token") - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager def _create_session(self, auth_req, sub_type="public", sector_identifier=""): if sector_identifier: @@ -219,7 +219,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - context=self.endpoint_context, + context=self.context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -236,7 +236,7 @@ def test_post_parse_request(self): auth_req = post_parse_request( AUTH_REQ, AUTH_REQ["client_id"], - self.endpoint_context, + self.context, http_info={ "headers": {"dpop": DPOP_HEADER}, "url": "https://server.example.com/token", @@ -252,7 +252,7 @@ def test_process_request(self): code = self._mint_code(grant, AUTH_REQ["client_id"]) _token_request = TOKEN_REQ.to_dict() - _context = self.endpoint_context + _context = self.context _token_request["code"] = code.value _req = self.token_endpoint.parse_request( _token_request, diff --git a/tests/test_server_61_add_on.py b/tests/test_server_61_add_on.py index b226af33..c9513cf7 100644 --- a/tests/test_server_61_add_on.py +++ b/tests/test_server_61_add_on.py @@ -136,8 +136,8 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(OPConfiguration(conf, base_path=BASEDIR), keyjar=KEYJAR) - self.endpoint_context = server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py index b86b6da0..a7745066 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -75,7 +75,7 @@ def full_path(local_file): USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) _OAUTH2_SERVICES = { - "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "claims": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, "refresh_access_token": { @@ -186,10 +186,10 @@ def create_endpoint(self): client_1_config = { "issuer": server_conf["issuer"], - "client_secret": "hemligt", + "client_secret": "hemligtlösenord", "client_id": "client_1", "redirect_uris": ["https://example.com/cb"], - "client_salt": "salted", + "client_salt": "salted_peanuts_cooking", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], @@ -197,9 +197,9 @@ def create_endpoint(self): client_2_config = { "issuer": server_conf["issuer"], "client_id": "client_2", - "client_secret": "hemligt", + "client_secret": "hemligtlösenord", "redirect_uris": ["https://example.com/cb"], - "client_salt": "salted", + "client_salt": "salted_peanuts_cooking", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], @@ -211,17 +211,19 @@ def create_endpoint(self): keyjar=build_keyjar(KEYDEFS), services=_OAUTH2_SERVICES) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = client_1_config - self.endpoint_context.cdb["client_2"] = client_2_config - self.endpoint_context.keyjar.import_jwks( - self.client_1.get_service_context().keyjar.export_jwks(), "client_1") - self.endpoint_context.keyjar.import_jwks( - self.client_2.get_service_context().keyjar.export_jwks(), "client_2") - - # self.endpoint = self.server.server_get("endpoint", "token") - # self.introspection_endpoint = self.server.server_get("endpoint", "introspection") - self.session_manager = self.endpoint_context.session_manager + self.context = self.server.context + self.context.cdb["client_1"] = client_1_config + self.context.cdb["client_2"] = client_2_config + self.context.keyjar.import_jwks( + self.client_1.keyjar.export_jwks(), "client_1") + self.context.keyjar.import_jwks( + self.client_2.keyjar.export_jwks(), "client_2") + + self.context.set_provider_info() + + # self.endpoint = self.server.upstream_get("endpoint", "token") + # self.introspection_endpoint = self.server.upstream_get("endpoint", "introspection") + self.session_manager = self.context.session_manager self.user_id = "diana" def do_query(self, service_type, endpoint_type, request_args, state): @@ -349,7 +351,7 @@ def test_token_exchange_per_client(self, token): """ Test that per-client token exchange configuration works correctly """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", diff --git a/tests/x_test_ciba_01_backchannel_auth.py b/tests/x_test_ciba_01_backchannel_auth.py index d96fc859..62d79ac3 100644 --- a/tests/x_test_ciba_01_backchannel_auth.py +++ b/tests/x_test_ciba_01_backchannel_auth.py @@ -191,8 +191,8 @@ class TestBCAEndpoint(object): @pytest.fixture(autouse=True) def create_endpoint(self): self.server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -211,9 +211,9 @@ def create_endpoint(self): self.server.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) self.server.keyjar.import_jwks(self.client_keyjar.export_jwks(), CLIENT_ID) - self.server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + self.server.context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} # login_hint - self.server.endpoint_context.login_hint_lookup = init_service( + self.server.context.login_hint_lookup = init_service( {"class": "idpyoidc.self.server.login_hint.LoginHintLookup"}, None ) # userinfo @@ -224,8 +224,8 @@ def create_endpoint(self): }, "", ) - self.server.endpoint_context.login_hint_lookup.userinfo = _userinfo - self.session_manager = self.server.endpoint_context.session_manager + self.server.context.login_hint_lookup.userinfo = _userinfo + self.session_manager = self.server.context.session_manager def test_login_hint_token(self): _jwt = JWT(self.client_keyjar, iss=CLIENT_ID, sign_alg="HS256") @@ -494,8 +494,8 @@ def create_endpoint(self): def _create_self.server(self): self.server = Server(OPConfiguration(SERVER_CONF, base_path=BASEDIR)) - endpoint_context = self.server.endpoint_context - endpoint_context.cdb["client_1"] = { + context = self.server.context + context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -512,9 +512,9 @@ def _create_self.server(self): self.server.keyjar.add_symmetric(CLIENT_ID, CLIENT_SECRET, ["sig"]) self.server.keyjar.import_jwks(client_keyjar.export_jwks(), CLIENT_ID) - self.server.endpoint_context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} + self.server.context.cdb = {CLIENT_ID: {"client_secret": CLIENT_SECRET}} # login_hint - self.server.endpoint_context.login_hint_lookup = init_service( + self.server.context.login_hint_lookup = init_service( {"class": "idpyoidc.self.server.login_hint.LoginHintLookup"}, None ) # userinfo @@ -525,7 +525,7 @@ def _create_self.server(self): }, "", ) - self.server.endpoint_context.login_hint_lookup.userinfo = _userinfo + self.server.context.login_hint_lookup.userinfo = _userinfo return self.server def _create_ciba_client(self): @@ -562,13 +562,13 @@ def _create_session(self, user_id, auth_req, sub_type="public", sector_identifie authz_req = auth_req client_id = authz_req["client_id"] ae = create_authn_event(user_id) - _session_manager = self.ciba["self.server"].endpoint_context.session_manager + _session_manager = self.ciba["self.server"].context.session_manager return _session_manager.create_session( ae, authz_req, user_id, client_id=client_id, sub_type=sub_type ) def test_client_notification(self): - _keyjar = self.ciba["self.server"].endpoint_context.keyjar + _keyjar = self.ciba["self.server"].context.keyjar _jwt = JWT(_keyjar, iss=CLIENT_ID, sign_alg="HS256") _jwt.with_jti = True _assertion = _jwt.pack({"aud": [ISSUER]}) @@ -589,7 +589,7 @@ def test_client_notification(self): _info = _authn_endpoint.process_request(req) assert _info - _session_manager = self.ciba["self.server"].endpoint_context.session_manager + _session_manager = self.ciba["self.server"].context.session_manager sid = _session_manager.auth_req_id_map[_info["response_args"]["auth_req_id"]] _user_id, _client_id, _grant_id = _session_manager.decrypt_session_id(sid) From d18e3d5293f3585eb0a2cd810d86249e3e9f7308 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 9 Feb 2023 09:32:18 +0100 Subject: [PATCH 53/76] Fixed Flake8 complains --- archdoc/docs/client/index.md | 59 +++++++++++++++++++ archdoc/docs/combo/index.md | 1 + archdoc/docs/server/index.md | 10 ++++ {src/idpyoidc => private}/actor/__init__.py | 0 .../actor/client/__init__.py | 0 .../actor/client/oidc/__init__.py | 0 .../actor/client/oidc/registration.py | 0 .../actor/server/__init__.py | 0 .../actor/server/oidc/__init__.py | 0 src/idpyoidc/__init__.py | 3 - src/idpyoidc/client/claims/__init__.py | 2 +- src/idpyoidc/client/claims/transform.py | 1 - src/idpyoidc/client/entity.py | 5 +- .../oauth2/add_on/pushed_authorization.py | 6 +- .../cc_refresh_access_token.py | 2 - src/idpyoidc/client/oidc/access_token.py | 4 +- src/idpyoidc/client/oidc/check_id.py | 2 +- src/idpyoidc/client/oidc/check_session.py | 2 +- src/idpyoidc/client/oidc/registration.py | 2 +- src/idpyoidc/client/rp_handler.py | 6 +- src/idpyoidc/client/service.py | 6 +- src/idpyoidc/client/service_context.py | 5 +- src/idpyoidc/configure.py | 2 +- src/idpyoidc/message/oauth2/__init__.py | 4 +- src/idpyoidc/metadata.py | 2 +- src/idpyoidc/server/__init__.py | 8 +-- src/idpyoidc/server/client_authn.py | 25 ++++---- src/idpyoidc/server/endpoint_context.py | 2 +- src/idpyoidc/server/oauth2/authorization.py | 2 +- src/idpyoidc/server/oauth2/token.py | 3 - src/idpyoidc/server/oauth2/token_helper.py | 15 +++-- .../server/oidc/backchannel_authentication.py | 31 +++++----- src/idpyoidc/server/oidc/registration.py | 1 - src/idpyoidc/server/oidc/userinfo.py | 2 +- src/idpyoidc/server/scopes.py | 2 - src/idpyoidc/server/session/grant.py | 4 +- src/idpyoidc/server/session/manager.py | 13 +--- src/idpyoidc/server/token/jwt_token.py | 31 +++++----- src/idpyoidc/server/user_authn/user.py | 12 ++-- src/idpyoidc/time_util.py | 2 +- 40 files changed, 162 insertions(+), 115 deletions(-) create mode 100644 archdoc/docs/client/index.md create mode 100644 archdoc/docs/combo/index.md create mode 100644 archdoc/docs/server/index.md rename {src/idpyoidc => private}/actor/__init__.py (100%) rename {src/idpyoidc => private}/actor/client/__init__.py (100%) rename {src/idpyoidc => private}/actor/client/oidc/__init__.py (100%) rename {src/idpyoidc => private}/actor/client/oidc/registration.py (100%) rename {src/idpyoidc => private}/actor/server/__init__.py (100%) rename {src/idpyoidc => private}/actor/server/oidc/__init__.py (100%) diff --git a/archdoc/docs/client/index.md b/archdoc/docs/client/index.md new file mode 100644 index 00000000..be03f230 --- /dev/null +++ b/archdoc/docs/client/index.md @@ -0,0 +1,59 @@ +# The IdpyOIDC client + +A client can send requests to an endpoint and deal with the response. + +IdpyOIDC assumes that there is one Relying Party(RP)/Client instance per +OpenID Connect Provider(OP)/Authorization Server (AS). + +If you have a service that expects to talk to several OPs/ASs +then you must use **idpyoidc.client.rp_handler.RPHandler** to manage the RPs. + +RPHandler has methods like: +- begin() +- finalize() +- refresh_access_token() +- logout() + +More about RPHandler at the end of this section. + +## Client + +A client is configured to talk to a set of services each of them represented by +a Service Instance. + +# Context + +# Service + +A Service instance is expected to be able to: + +1. Collect all the request arguments +2. If necessary collect and add authentication information to the request attributes or HTTP header +3. Formats the message +4. chooses HTTP method +5. Add HTTP headers + +and then after having received the response: + +1. Parses the response +2. Gather verification information and verify the response +3. Do any special post-processing. +3. Store information from the response + +Doesn't matter which service is considered they all have to be able to do this. + +## Request + +## Response + +# AddOn + +# Endpoints + +## OAuth2 + +- Access Token +- Authorization +- Refresh Access Token +- Server Metadata +- Token Exchange diff --git a/archdoc/docs/combo/index.md b/archdoc/docs/combo/index.md new file mode 100644 index 00000000..ab4a67a7 --- /dev/null +++ b/archdoc/docs/combo/index.md @@ -0,0 +1 @@ +# An entity that can act both as a server and a client diff --git a/archdoc/docs/server/index.md b/archdoc/docs/server/index.md new file mode 100644 index 00000000..415bfff3 --- /dev/null +++ b/archdoc/docs/server/index.md @@ -0,0 +1,10 @@ + +# Service + +## Request + +## Response + +# Context + +# AddOn diff --git a/src/idpyoidc/actor/__init__.py b/private/actor/__init__.py similarity index 100% rename from src/idpyoidc/actor/__init__.py rename to private/actor/__init__.py diff --git a/src/idpyoidc/actor/client/__init__.py b/private/actor/client/__init__.py similarity index 100% rename from src/idpyoidc/actor/client/__init__.py rename to private/actor/client/__init__.py diff --git a/src/idpyoidc/actor/client/oidc/__init__.py b/private/actor/client/oidc/__init__.py similarity index 100% rename from src/idpyoidc/actor/client/oidc/__init__.py rename to private/actor/client/oidc/__init__.py diff --git a/src/idpyoidc/actor/client/oidc/registration.py b/private/actor/client/oidc/registration.py similarity index 100% rename from src/idpyoidc/actor/client/oidc/registration.py rename to private/actor/client/oidc/registration.py diff --git a/src/idpyoidc/actor/server/__init__.py b/private/actor/server/__init__.py similarity index 100% rename from src/idpyoidc/actor/server/__init__.py rename to private/actor/server/__init__.py diff --git a/src/idpyoidc/actor/server/oidc/__init__.py b/private/actor/server/oidc/__init__.py similarity index 100% rename from src/idpyoidc/actor/server/oidc/__init__.py rename to private/actor/server/oidc/__init__.py diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 1ca2d4a7..c7216254 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,9 +1,6 @@ __author__ = "Roland Hedberg" __version__ = "2.0.0" -import os -from typing import Dict - VERIFIED_CLAIM_PREFIX = "__verified" diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py index 66365344..f303e9e2 100644 --- a/src/idpyoidc/client/claims/__init__.py +++ b/src/idpyoidc/client/claims/__init__.py @@ -53,7 +53,7 @@ def get_jwks(self, keyjar): # if only one key under the id == "", that key being a SYMKey I assume it's # and I have a client_secret then don't publish a JWKS if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey) and self.prefer[ - 'client_secret']: + 'client_secret']: pass else: _jwks = keyjar.export_jwks() diff --git a/src/idpyoidc/client/claims/transform.py b/src/idpyoidc/client/claims/transform.py index ec67b790..87e4e2d0 100644 --- a/src/idpyoidc/client/claims/transform.py +++ b/src/idpyoidc/client/claims/transform.py @@ -1,7 +1,6 @@ import logging from typing import Optional -from idpyoidc.message import Message from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 2b876581..d00212dd 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -1,4 +1,5 @@ import logging +from typing import Callable from typing import Optional from typing import Union @@ -16,7 +17,6 @@ from idpyoidc.client.service_context import ServiceContext from idpyoidc.context import OidcContext from idpyoidc.node import Unit -from idpyoidc.server.client_authn import client_auth_class logger = logging.getLogger(__name__) @@ -200,4 +200,5 @@ def import_keys(self, keyspec): return _keyjar def get_callback_uris(self): - return self.context.claims.callback_uri \ No newline at end of file + return self.context.claims.callback_uri + diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py index 4c67bdcd..611a0008 100644 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py @@ -1,7 +1,6 @@ import logging from cryptojwt import JWT - from requests import request from idpyoidc.message import Message @@ -26,7 +25,7 @@ def push_authorization(request_args, service, **kwargs): if method_args["body_format"] == "urlencoded": _body = request_args.to_urlencoded() else: - _jwt = JWT(key_jar=service.upstream_get('attribute','keyjar'), + _jwt = JWT(key_jar=service.upstream_get('attribute', 'keyjar'), iss=_context.base_url) _jws = _jwt.pack(request_args.to_dict()) @@ -56,7 +55,8 @@ def push_authorization(request_args, service, **kwargs): def add_support( - services, body_format="jws", signing_algorithm="RS256", http_client=None, merge_rule="strict" + services, body_format="jws", signing_algorithm="RS256", http_client=None, + merge_rule="strict" ): """ Add the necessary pieces to make Demonstration of proof of possession (DPOP). diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py index 69ac5ff5..6ab144fc 100644 --- a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py @@ -1,5 +1,3 @@ -from typing import Optional - from idpyoidc.client.service import Service from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 4fc8fb7d..547f0ed2 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -2,11 +2,11 @@ from typing import Optional from typing import Union +from idpyoidc.claims import get_signing_algs from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.exception import ParameterError from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import IDT2REG -from idpyoidc.claims import get_signing_algs from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name @@ -71,7 +71,7 @@ def gather_verify_arguments( return kwargs - def update_service_context(self, resp, key: Optional[str] ="", **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): _cstate = self.upstream_get("context").cstate try: _idt = resp[verified_claim_name("id_token")] diff --git a/src/idpyoidc/client/oidc/check_id.py b/src/idpyoidc/client/oidc/check_id.py index 38e5897f..3e33e3c7 100644 --- a/src/idpyoidc/client/oidc/check_id.py +++ b/src/idpyoidc/client/oidc/check_id.py @@ -23,7 +23,7 @@ def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct = [self.oidc_pre_construct] - def oidc_pre_construct(self, request_args: Optional[dict]=None, **kwargs): + def oidc_pre_construct(self, request_args: Optional[dict] = None, **kwargs): _args = self.upstream_get("context").cstate.get_set( kwargs["state"], claim=["id_token"] diff --git a/src/idpyoidc/client/oidc/check_session.py b/src/idpyoidc/client/oidc/check_session.py index 373f5242..b089e2d3 100644 --- a/src/idpyoidc/client/oidc/check_session.py +++ b/src/idpyoidc/client/oidc/check_session.py @@ -24,7 +24,7 @@ def __init__(self, upstream_get, conf=None): def oidc_pre_construct(self, request_args=None, **kwargs): _args = self.upstream_get("context").cstate.get_set(kwargs["state"], - claim=["id_token"]) + claim=["id_token"]) if request_args: request_args.update(_args) else: diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 4f202a62..3c6ac713 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -108,4 +108,4 @@ def gather_request_args(self, **kwargs): req_args.update(self.conf["request_args"]) req_args.update(kwargs) - return req_args \ No newline at end of file + return req_args diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 59589506..2ceb0e50 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -180,7 +180,7 @@ def init_client(self, issuer): except KeyError: _services = self.services - if not 'base_url' in _cnf: + if 'base_url' not in _cnf: _cnf['base_url'] = self.base_url if self.jwks_uri: @@ -574,7 +574,7 @@ def get_tokens(self, state, client: Optional[Client] = None): authn_method=self.get_client_authn_method(client, "token_endpoint"), state=state, ) - except Exception as err: + except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) raise @@ -613,7 +613,7 @@ def refresh_access_token(self, state, client=None, scope=""): state=state, request_args=req_args, ) - except Exception as err: + except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) raise diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 762cd403..ece09d01 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -525,8 +525,7 @@ def _do_jwt(self, info): args["allowed_enc_algs"] = enc_algs["alg"] args["allowed_enc_encs"] = enc_algs["enc"] - - _jwt = JWT(key_jar=self.upstream_get('attribute','keyjar'), **args) + _jwt = JWT(key_jar=self.upstream_get('attribute', 'keyjar'), **args) _jwt.iss = _context.get_client_id() return _jwt.unpack(info) @@ -675,7 +674,8 @@ def construct_uris(self, else: _path = self._callback_path.get(uri) if isinstance(_path, str): - _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(_path), hex) + _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(_path), + hex) else: _callback_uris[uri] = [self.get_uri(base_url, self._callback_path.get(_var), hex) for _var in _path] diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 1da9765e..b49fcd69 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -2,15 +2,14 @@ Implements a service context. A Service context is used to keep information that are common between all the services that are used by OAuth2 client or OpenID Connect Relying Party. """ -import copy import hashlib import logging from typing import Callable from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.jwk.rsa import RSAKey +from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes @@ -152,7 +151,7 @@ def __init__(self, self.client_secret_expires_at = 0 self.registration_response = {} - _def_value = copy.deepcopy(DEFAULT_VALUE) + # _def_value = copy.deepcopy(DEFAULT_VALUE) _issuer = config.get("issuer") if _issuer: diff --git a/src/idpyoidc/configure.py b/src/idpyoidc/configure.py index 9b74135d..5258b576 100644 --- a/src/idpyoidc/configure.py +++ b/src/idpyoidc/configure.py @@ -265,7 +265,7 @@ def __init__( if entity_conf: skip = [ec["path"] for ec in entity_conf if "path" in ec] - check = [l[0] for l in skip] + check = [word[0] for word in skip] self.extend( conf=self.conf, diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 2660aec0..11852a2d 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -10,7 +10,6 @@ from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import VerificationError from idpyoidc.message import Message -from idpyoidc.message import msg_ser from idpyoidc.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from idpyoidc.message import OPTIONAL_LIST_OF_STRINGS from idpyoidc.message import REQUIRED_LIST_OF_SP_SEP_STRINGS @@ -21,6 +20,7 @@ from idpyoidc.message import SINGLE_REQUIRED_BOOLEAN from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING +from idpyoidc.message import msg_ser logger = logging.getLogger(__name__) @@ -542,6 +542,7 @@ class SecurityEventToken(Message): "toe": SINGLE_OPTIONAL_INT, } + class JWTAccessToken(Message): c_param = { "iss": SINGLE_REQUIRED_STRING, @@ -561,7 +562,6 @@ class JWTAccessToken(Message): } - class JSONWebToken(Message): # implements RFC 9068 c_param = { diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py index 7104d82a..55a90fbc 100644 --- a/src/idpyoidc/metadata.py +++ b/src/idpyoidc/metadata.py @@ -244,7 +244,7 @@ def alg_cmp(a, b): def get_signing_algs(): # Assumes Cryptojwt - _algs = [l for l in list(SIGNER_ALGS.keys()) if l != 'none'] + _algs = [name for name in list(SIGNER_ALGS.keys()) if name != 'none'] return sorted(_algs, key=cmp_to_key(alg_cmp)) diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 59b93e14..0a277ca8 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -8,14 +8,14 @@ from cryptojwt import KeyJar from idpyoidc.node import Unit -from idpyoidc.server import authz -from idpyoidc.server.client_authn import client_auth_setup +# from idpyoidc.server import authz +# from idpyoidc.server.client_authn import client_auth_setup from idpyoidc.server.configure import ASConfiguration from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.endpoint_context import EndpointContext -from idpyoidc.server.session.manager import create_session_manager -from idpyoidc.server.user_authn.authn_context import populate_authn_broker +# from idpyoidc.server.session.manager import create_session_manager +# from idpyoidc.server.user_authn.authn_context import populate_authn_broker from idpyoidc.server.util import allow_refresh_token from idpyoidc.server.util import build_endpoints diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index a7b4c1d9..ec671beb 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -3,18 +3,13 @@ from typing import Callable from typing import Dict from typing import Optional -from typing import TYPE_CHECKING from typing import Union -if TYPE_CHECKING: - from idpyoidc.server.endpoint_context import EndpointContext - from cryptojwt.exception import BadSignature from cryptojwt.exception import Invalid from cryptojwt.exception import MissingKey from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac -from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes from cryptojwt.utils import as_unicode @@ -27,7 +22,6 @@ from idpyoidc.server.exception import InvalidClient from idpyoidc.server.exception import InvalidToken from idpyoidc.server.exception import ToOld -from idpyoidc.server.exception import UnAuthorizedClient from idpyoidc.server.exception import UnknownClient from idpyoidc.util import importer from idpyoidc.util import sanitize @@ -259,6 +253,7 @@ def _verify( request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, **kwargs, ): _token = request.get("access_token") @@ -266,7 +261,8 @@ def _verify( raise ClientAuthenticationError("No access token") res = {"token": _token} - _client_id = get_client_id_from_token(endpoint_context, _token, request) + _context = self.upstream_get('context') + _client_id = get_client_id_from_token(_context, _token, request) if _client_id: res["client_id"] = _client_id return res @@ -290,7 +286,7 @@ def _verify( **kwargs, ): _context = self.upstream_get('context') - _keyjar = self.upstream_get('attribute','keyjar') + _keyjar = self.upstream_get('attribute', 'keyjar') _jwt = JWT(_keyjar, msg_cls=JsonWebToken) try: ca_jwt = _jwt.unpack(request["client_assertion"]) @@ -477,7 +473,8 @@ def verify_client( authorization_token = None auth_info = {} - methods = context.client_authn_method + _context = endpoint.upstream_get('context') + methods = _context.client_authn_method client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: @@ -490,7 +487,7 @@ def verify_client( try: logger.info(f"Verifying client authentication using {_method.tag}") auth_info = _method.verify( - keyjar=keyjar, + keyjar=endpoint.upstream_get('attribute', 'keyjar'), request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -513,10 +510,10 @@ def verify_client( client_id = also_known_as[client_id] auth_info["client_id"] = client_id - if client_id not in context.cdb: + if client_id not in _context.cdb: raise UnknownClient("Unknown Client ID") - _cinfo = context.cdb[client_id] + _cinfo = _context.cdb[client_id] if not valid_client_info(_cinfo): logger.warning("Client registration has timed out or " "client secret is expired.") @@ -540,9 +537,9 @@ def verify_client( _request_type = request.__class__.__name__ _used_authn_method = _cinfo.get("auth_method") if _used_authn_method: - context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] + _context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] else: - context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} + _context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} return auth_info diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index fa881379..190771fd 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -280,7 +280,7 @@ def __init__( self.setup_client_authn_methods() - _id_token_handler = self.session_manager.token_handler.handler.get("id_token") + # _id_token_handler = self.session_manager.token_handler.handler.get("id_token") # if _id_token_handler: # self.provider_info.update(_id_token_handler.provider_info) diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 0a74922a..74fdf213 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -517,7 +517,7 @@ def _post_parse_request(self, request, client_id, context, **kwargs): request["redirect_uri"] = redirect_uri if ("resource_indicators" in _cinfo - and "authorization_code" in _cinfo["resource_indicators"]): + and "authorization_code" in _cinfo["resource_indicators"]): resource_indicators_config = _cinfo["resource_indicators"]["authorization_code"] else: resource_indicators_config = self.resource_indicators_config diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 2ba4ecf5..20034b84 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -7,16 +7,13 @@ from idpyoidc.message import Message from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.message.oidc import TokenErrorResponse -from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError from idpyoidc.server.oauth2.token_helper import AccessTokenHelper from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper from idpyoidc.server.session import MintingNotAllowed -from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING from idpyoidc.util import importer logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index f55296df..8836697b 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -102,6 +102,7 @@ def _mint_token( return token + def validate_resource_indicators_policy(request, context, **kwargs): if "resource" not in request: return TokenErrorResponse( @@ -114,7 +115,8 @@ def validate_resource_indicators_policy(request, context, **kwargs): resource_servers_per_client = kwargs.get("resource_servers_per_client", None) - if isinstance(resource_servers_per_client, dict) and client_id not in resource_servers_per_client: + if isinstance(resource_servers_per_client, + dict) and client_id not in resource_servers_per_client: return TokenErrorResponse( error="invalid_target", error_description=f"Resources for client {client_id} not found", @@ -183,7 +185,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): _cinfo = self.endpoint.server_get("endpoint_context").cdb.get(client_id) if ("resource_indicators" in _cinfo - and "access_token" in _cinfo["resource_indicators"]): + and "access_token" in _cinfo["resource_indicators"]): resource_indicators_config = _cinfo["resource_indicators"]["access_token"] else: resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None) @@ -198,10 +200,11 @@ def process_request(self, req: Union[Message, dict], **kwargs): if isinstance(req, TokenErrorResponse): return req - if "grant_types_supported" in _context.cdb[client_id]: - grant_types_supported = _context.cdb[client_id].get("grant_types_supported") - else: - grant_types_supported = _context.provider_info["grant_types_supported"] + # if "grant_types_supported" in _context.cdb[client_id]: + # grant_types_supported = _context.cdb[client_id].get("grant_types_supported") + # else: + # grant_types_supported = _context.provider_info["grant_types_supported"] + grant = _session_info["grant"] _based_on = grant.get_token(_access_code) diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index 60134e90..941010b6 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -1,8 +1,8 @@ import logging -import uuid from typing import Callable from typing import Optional from typing import Union +import uuid from cryptojwt.jwe.exception import JWEException from cryptojwt.jws.exception import NoSuitableSigningKeys @@ -14,7 +14,6 @@ from idpyoidc.message.oidc.backchannel_authentication import AuthenticationRequest from idpyoidc.message.oidc.backchannel_authentication import AuthenticationResponse from idpyoidc.server import Endpoint -from idpyoidc.server import EndpointContext from idpyoidc.server.client_authn import ClientSecretBasic from idpyoidc.server.exception import NoSuchAuthentication from idpyoidc.server.oidc.token_helper import AccessTokenHelper @@ -86,10 +85,10 @@ def allowed_target_uris(self): return set(res) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ): try: request_user = self.do_request_user(request) @@ -137,7 +136,7 @@ def _get_session_info(self, request, session_manager): return session_info, _grant def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -303,10 +302,10 @@ def __init__(self, upstream_get: Callable, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ) -> Union[Message, dict]: return {} @@ -322,11 +321,11 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): ttype, token = authorization_token.split(" ", 1) if ttype != "Bearer": diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 09709f71..9b1cdef7 100755 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -167,7 +167,6 @@ def match_claim(self, claim, val): def filter_client_request(self, request: dict) -> dict: _args = {} _context = self.upstream_get("context") - _provider_info = _context.provider_info for key, val in request.items(): if key not in _context.claims.register2preferred: _args[key] = val diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 31bbce7d..0beed342 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -180,7 +180,7 @@ def parse_request(self, request, http_info=None, **kwargs): # Verify that the client is allowed to do this try: auth_info = self.client_authentication(request, http_info, **kwargs) - except ClientAuthenticationError as e: + except ClientAuthenticationError: return self.error_cls(error="invalid_token", error_description="Invalid token") if isinstance(auth_info, ResponseMessage): diff --git a/src/idpyoidc/server/scopes.py b/src/idpyoidc/server/scopes.py index 2e2bed27..2f5bf76e 100644 --- a/src/idpyoidc/server/scopes.py +++ b/src/idpyoidc/server/scopes.py @@ -1,5 +1,3 @@ -from idpyoidc.server.exception import ConfigurationError - # default set can be changed by configuration SCOPE2CLAIMS = { diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index 3d9060ef..d09543c6 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -181,7 +181,7 @@ def add_acr_value(self, claims_release_point): def payload_arguments( self, session_id: str, - context: 'EndpointContext', + context: "EndpointContext", item: SessionToken, claims_release_point: str, scope: Optional[dict] = None, @@ -599,4 +599,4 @@ def payload_arguments( elif self.add_acr_value(secondary_identifier): payload["acr"] = self.authentication_event["authn_info"] - return payload \ No newline at end of file + return payload diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 2e2338ca..6a33f8ac 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -7,7 +7,6 @@ import uuid from idpyoidc.encrypter import default_crypt_config -from idpyoidc.encrypter import get_crypt_config from idpyoidc.message.oauth2 import AuthorizationRequest from idpyoidc.message.oauth2 import TokenExchangeRequest from idpyoidc.server.authn_event import AuthnEvent @@ -195,14 +194,6 @@ def create_grant( if "resource" in auth_req: resources = auth_req["resource"] - if self.node_type[0] == "user": - kwargs = { - "sub": self.sub_func[sub_type]( - user_id, salt=self.get_salt(), sector_identifier=sector_identifier) - } - else: - kwargs = {} - return self.add_grant( path=self.make_path(user_id=user_id, client_id=client_id), token_usage_rules=token_usage_rules, @@ -216,7 +207,7 @@ def create_grant( claims=_claims, remember_token=self.remember_token, remove_inactive_token=self.remove_inactive_token, - resources=resources + resources=resources, ) def create_exchange_grant( @@ -499,7 +490,7 @@ def get_session_info_by_token( authorization_request: Optional[bool] = False, handler_key: Optional[str] = "", ) -> dict: - + if handler_key: _token_info = self.token_handler.handler[handler_key].info(token_value) else: diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index 6cb12d7a..ec125921 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -7,8 +7,8 @@ from cryptojwt.utils import importer from idpyoidc.server.exception import ToOld -from . import is_expired from . import Token +from . import is_expired from .exception import UnknownToken from .exception import WrongTokenClass from ..constant import DEFAULT_TOKEN_LIFETIME @@ -19,18 +19,18 @@ class JWTToken(Token): def __init__( - self, - token_class, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = DEFAULT_TOKEN_LIFETIME, - upstream_get: Callable = None, - token_type: str = "Bearer", - profile: Optional[Union[Message, str]] = JWTAccessToken, - with_jti: Optional[bool] = False, - **kwargs + self, + token_class, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = DEFAULT_TOKEN_LIFETIME, + upstream_get: Callable = None, + token_type: str = "Bearer", + profile: Optional[Union[Message, str]] = JWTAccessToken, + with_jti: Optional[bool] = False, + **kwargs ): Token.__init__(self, token_class, **kwargs) self.token_type = token_type @@ -85,13 +85,12 @@ def __call__( payload = self.load_custom_claims(payload) # payload.update(kwargs) - _context = self.upstream_get("context") if usage_rules and "expires_in" in usage_rules: lifetime = usage_rules.get("expires_in") else: lifetime = self.lifetime signer = JWT( - key_jar=self.upstream_get('attribute','keyjar'), + key_jar=self.upstream_get('attribute', 'keyjar'), iss=self.issuer, lifetime=lifetime, sign_alg=self.alg, @@ -112,7 +111,7 @@ def __call__( return signer.pack(payload) def get_payload(self, token): - verifier = JWT(key_jar=self.upstream_get('attribute','keyjar'), + verifier = JWT(key_jar=self.upstream_get('attribute', 'keyjar'), allowed_sign_algs=[self.alg]) try: _payload = verifier.unpack(token) diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index b1dd008c..c5a2fd47 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -153,13 +153,13 @@ class UserPassJinja2(UserAuthnMethod): url_endpoint = "/verify/user_pass_jinja" def __init__( - self, - db, - template_handler, - template="user_pass.jinja2", + self, + db, + template_handler, + template="user_pass.jinja2", upstream_get=None, - verify_endpoint="", - **kwargs, + verify_endpoint="", + **kwargs, ): super(UserPassJinja2, self).__init__(upstream_get=upstream_get) diff --git a/src/idpyoidc/time_util.py b/src/idpyoidc/time_util.py index 3ff0838b..faa48655 100644 --- a/src/idpyoidc/time_util.py +++ b/src/idpyoidc/time_util.py @@ -27,7 +27,7 @@ from datetime import timezone TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" -TIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") +qTIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") logger = logging.getLogger(__name__) From 916c0c5d6a9307ccbb49d1455ad0d1e10bbd6fda Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 9 Feb 2023 09:35:55 +0100 Subject: [PATCH 54/76] Fixed Flake8 complains --- src/idpyoidc/server/endpoint.py | 2 +- src/idpyoidc/server/session/grant.py | 2 +- src/idpyoidc/time_util.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 1514a344..79dd4d6d 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -139,7 +139,7 @@ def set_client_authn_methods(self, **kwargs): self.client_authn_method = ["none"] # Ignore default value elif self.default_capabilities: self.client_authn_method = self.default_capabilities.get("client_authn_method") - self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) + # self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) return kwargs def process_verify_error(self, exception): diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index d09543c6..6f193adb 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -181,7 +181,7 @@ def add_acr_value(self, claims_release_point): def payload_arguments( self, session_id: str, - context: "EndpointContext", + context: object, item: SessionToken, claims_release_point: str, scope: Optional[dict] = None, diff --git a/src/idpyoidc/time_util.py b/src/idpyoidc/time_util.py index faa48655..0837e367 100644 --- a/src/idpyoidc/time_util.py +++ b/src/idpyoidc/time_util.py @@ -26,6 +26,8 @@ from datetime import timedelta from datetime import timezone +from idpyoidc.time_util import TIME_FORMAT_WITH_FRAGMENT + TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" qTIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") From efb63cce142d81d8a0ad2f53ac568b504c3d2cfa Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 9 Feb 2023 13:28:32 +0100 Subject: [PATCH 55/76] Fixed tests --- example/flask_rp/views.py | 10 ++-- src/idpyoidc/client/service_context.py | 2 +- src/idpyoidc/server/client_authn.py | 2 +- src/idpyoidc/server/endpoint.py | 2 - src/idpyoidc/server/oauth2/authorization.py | 11 ++-- src/idpyoidc/server/oauth2/token.py | 1 + src/idpyoidc/server/oauth2/token_helper.py | 4 +- src/idpyoidc/time_util.py | 4 +- tests/private/token_jwks.json | 2 +- tests/test_client_01_service_context.py | 3 +- .../test_client_14_service_context_impexp.py | 8 +-- tests/test_client_23_pkce.py | 8 +-- tests/test_client_28_rp_handler_oidc.py | 1 - tests/test_server_10_session_manager.py | 6 +-- tests/test_server_17_client_authn.py | 11 ++-- tests/test_server_20d_client_authn.py | 51 ++----------------- tests/test_server_20f_userinfo.py | 6 +-- ...st_server_24_oauth2_resource_indicators.py | 34 ++++++------- tests/test_server_24_oauth2_token_endpoint.py | 19 ++++--- .../test_server_26_oidc_userinfo_endpoint.py | 16 +++--- tests/test_server_36_oauth2_token_exchange.py | 48 ++++++++--------- 21 files changed, 102 insertions(+), 147 deletions(-) diff --git a/example/flask_rp/views.py b/example/flask_rp/views.py index e3f7b64f..2cb47934 100644 --- a/example/flask_rp/views.py +++ b/example/flask_rp/views.py @@ -100,7 +100,7 @@ def finalize(op_identifier, request_args): logger.error(rp.response[0].decode()) return rp.response[0], rp.status_code - _context = rp.client_get("context") + _context = rp.get_context() session['client_id'] = _context.get('client_id') session['state'] = request_args.get('state') @@ -123,7 +123,7 @@ def finalize(op_identifier, request_args): raise excp if 'userinfo' in res: - _context = rp.client_get("context") + _context = rp.get_context() endpoints = {} for k, v in _context.provider_info.items(): if k.endswith('_endpoint'): @@ -197,7 +197,7 @@ def session_iframe(): # session management logger.debug('session_iframe request_args: {}'.format(request.args)) _rp = get_rp(session['op_identifier']) - _context = _rp.client_get("context") + _context = _rp.get_context() session_change_url = "{}/session_change".format(_context.base_url) _issuer = current_app.rph.hash2issuer[session['op_identifier']] @@ -237,7 +237,7 @@ def session_change(): def session_logout(op_identifier): _rp = get_rp(op_identifier) logger.debug('post_logout') - return "Post logout from {}".format(_rp.client_get("context").issuer) + return "Post logout from {}".format(_rp.get_context().issuer) # RP initiated logout @@ -267,7 +267,7 @@ def frontchannel_logout(op_identifier): _rp = get_rp(op_identifier) sid = request.args['sid'] _iss = request.args['iss'] - if _iss != _rp.client_get("context").get('issuer'): + if _iss != _rp.get_context().get('issuer'): return 'Bad request', 400 _state = _rp.session_interface.get_state_by_sid(sid) _rp.session_interface.remove_state(_state) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index b49fcd69..e0ea87d0 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -139,7 +139,7 @@ def __init__(self, self.kid = {"sig": {}, "enc": {}} self.allow = config.conf.get('allow', {}) - self.base_url = base_url or config.get("base_url", "") + self.base_url = base_url or config.conf.get("base_url", "") self.provider_info = config.conf.get("provider_info", {}) # Below so my IDE won't complain diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index ec671beb..a414a1bc 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -474,7 +474,7 @@ def verify_client( auth_info = {} _context = endpoint.upstream_get('context') - methods = _context.client_authn_method + methods = _context.client_authn_methods client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 79dd4d6d..e1c5f5bb 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -137,8 +137,6 @@ def set_client_authn_methods(self, **kwargs): kwargs[self.auth_method_attribute] = _methods elif _methods is not None: # [] or '' or something not None but regarded as nothing. self.client_authn_method = ["none"] # Ignore default value - elif self.default_capabilities: - self.client_authn_method = self.default_capabilities.get("client_authn_method") # self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) return kwargs diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 74fdf213..df2ca321 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -14,7 +14,7 @@ from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e -from idpyoidc import work_environment +from idpyoidc import claims from idpyoidc.exception import ImproperlyConfigured from idpyoidc.exception import ParameterError from idpyoidc.exception import URIError @@ -344,9 +344,9 @@ class Authorization(Endpoint): "request_uri_parameter_supported": True, "response_types_supported": ["code", "token", "code token"], "response_modes_supported": ["query", "fragment", "form_post"], - "request_object_signing_alg_values_supported": work_environment.get_signing_algs, - "request_object_encryption_alg_values_supported": work_environment.get_encryption_algs, - "request_object_encryption_enc_values_supported": work_environment.get_encryption_encs, + "request_object_signing_alg_values_supported": claims.get_signing_algs, + "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + "request_object_encryption_enc_values_supported": claims.get_encryption_encs, "grant_types_supported": ["authorization_code", "implicit"], "scopes_supported": [], } @@ -359,6 +359,7 @@ def __init__(self, upstream_get, **kwargs): self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) + self.resource_indicators_config = kwargs.get('resource_indicators', None) def filter_request(self, context, req): return req @@ -531,7 +532,7 @@ def _post_parse_request(self, request, client_id, context, **kwargs): return request def _enforce_resource_indicators_policy(self, request, config): - _context = self.server_get("endpoint_context") + _context = self.upstream_get("context") policy = config["policy"] callable = policy["callable"] diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 20034b84..e20dc4ac 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -44,6 +44,7 @@ def __init__(self, upstream_get, new_refresh_token=False, **kwargs): self.grant_types_supported = kwargs.get("grant_types_supported", list(self.helper_by_grant_type.keys())) self.revoke_refresh_on_issue = kwargs.get("revoke_refresh_on_issue", False) + self.resource_indicators_config = kwargs.get('resource_indicators', None) def configure_grant_types(self, grant_types_helpers): if grant_types_helpers is None: diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index 8836697b..a7edbcd3 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -182,7 +182,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): logger.warning("Client using token it was not given") return self.error_cls(error="invalid_grant", error_description="Wrong client") - _cinfo = self.endpoint.server_get("endpoint_context").cdb.get(client_id) + _cinfo = self.endpoint.upstream_get("context").cdb.get(client_id) if ("resource_indicators" in _cinfo and "access_token" in _cinfo["resource_indicators"]): @@ -283,7 +283,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _response def _enforce_resource_indicators_policy(self, request, config): - _context = self.endpoint.server_get("endpoint_context") + _context = self.endpoint.upstream_get('context') policy = config["policy"] callable = policy["callable"] diff --git a/src/idpyoidc/time_util.py b/src/idpyoidc/time_util.py index 0837e367..3ff0838b 100644 --- a/src/idpyoidc/time_util.py +++ b/src/idpyoidc/time_util.py @@ -26,10 +26,8 @@ from datetime import timedelta from datetime import timezone -from idpyoidc.time_util import TIME_FORMAT_WITH_FRAGMENT - TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" -qTIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") +TIME_FORMAT_WITH_FRAGMENT = re.compile("^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") logger = logging.getLogger(__name__) diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 5375bc5f..f8537bbb 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "XKrr1hBNC6l5na2jxwVbksUmtGzcRrJF"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "Cy6KTXiLPPvNj33-4kuAsk2diPuoZkCC"}]} \ No newline at end of file diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index a4637a42..0143be2f 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -26,7 +26,8 @@ class TestServiceContext: @pytest.fixture(autouse=True) def setup(self): self.unit = Unit() - self.service_context = ServiceContext(config=MINI_CONFIG, upstream_get=self.unit.unit_get) + self.service_context = ServiceContext(config=MINI_CONFIG, upstream_get=self.unit.unit_get, + base_url="https://example.com/cli") def test_init(self): assert self.service_context diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index 27e3f1ad..f0ec76e3 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -19,12 +19,12 @@ def test_client_info_init(): "base_url": BASE_URL, "requests_dir": "requests", } - ci = ServiceContext(config=config, client_type='oidc') + ci = ServiceContext(config=config, client_type='oidc', base_url=BASE_URL) ci.claims.load_conf(config, supports=ci.supports()) ci.map_supported_to_preferred() ci.map_preferred_to_registered() - srvcnx = ServiceContext(base_url=BASE_URL).load(ci.dump()) + srvcnx = ServiceContext().load(ci.dump()) for attr in config.keys(): if attr == "client_id": @@ -62,8 +62,8 @@ def test_client_filename(): "base_url": "https://example.com", "requests_dir": "requests", } - service_context = ServiceContext(config=config) - srvcnx2 = ServiceContext(base_url=BASE_URL).load(service_context.dump()) + service_context = ServiceContext(config=config, base_url=BASE_URL) + srvcnx2 = ServiceContext().load(service_context.dump()) fname = srvcnx2.filename_from_webname("https://example.com/rq12345") assert fname == "rq12345" diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index 55c189d1..e7882822 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -85,8 +85,8 @@ def test_add_code_challenge_default_values(self): assert len(request_args["code_verifier"]) == 64 def test_authorization_and_pkce(self): - auth_serv = self.entity.client_get("service", "authorization") - _state = self.entity.get_context().state.create_state(iss="Issuer") + auth_serv = self.entity.get_service("authorization") + _state = self.entity.get_context().cstate.create_state(iss="Issuer") request = auth_serv.construct_request({"state": _state, "response_type": "code"}) assert set(request.keys()) == { @@ -105,8 +105,8 @@ def test_access_token_and_pkce(self): auth_response = AuthorizationResponse(code="access code") _context = self.entity.get_context() _context.cstate.update(_state, auth_response) - auth_serv = self.entity.get_service("authorization") - _state = _context.cstate.create_state(iss="Issuer") + #auth_serv = self.entity.get_service("authorization") + #_state = _context.cstate.create_state(iss="Issuer") token_service = self.entity.get_service("accesstoken") request = token_service.construct_request(state=_state) diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index e480387c..f208d211 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -325,7 +325,6 @@ def test_create_callbacks(self): client = self.rph.init_client("https://op.example.com/") _srv = client.get_service("registration") _context = _srv.upstream_get("context") - cb = _context.get_preference('callback_uris') assert set(cb.keys()) == {"request_uris", "redirect_uris"} diff --git a/tests/test_server_10_session_manager.py b/tests/test_server_10_session_manager.py index 5dca75a7..1518e7bb 100644 --- a/tests/test_server_10_session_manager.py +++ b/tests/test_server_10_session_manager.py @@ -92,7 +92,7 @@ def create_session_manager(self): } server = Server(conf) self.server = server - self.endpoint_context = server.endpoint_context + self.endpoint_context = server.context self.endpoint_context.cdb = { "client_1": { "client_secret": "hemligt", @@ -111,7 +111,7 @@ def create_session_manager(self): } } - self.session_manager = server.endpoint_context.session_manager + self.session_manager = server.context.session_manager self.authn_event = AuthnEvent( uid="uid", valid_until=utc_time_sans_frac() + 1, authn_info="authn_class_ref" ) @@ -128,7 +128,7 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): client_id = authz_req["client_id"] ae = create_authn_event(USER_ID) - return self.server.endpoint_context.session_manager.create_session( + return self.server.context.session_manager.create_session( ae, authz_req, USER_ID, client_id=client_id, sub_type=sub_type ) diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index 135ffee1..0fe2d533 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -502,12 +502,11 @@ def test_verify_per_client_per_endpoint(self): ) assert res == {"method": "public", "client_id": client_id} - with pytest.raises(ClientAuthenticationError) as e: - verify_client( - request=request, - endpoint=self.server.get_endpoint("endpoint_1"), - ) - assert e.value.args[0] == "Failed to verify client" + res = verify_client( + request=request, + endpoint=self.server.get_endpoint("endpoint_1"), + ) + assert res == {} request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index 9af23ca0..0c392c5b 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -1,13 +1,13 @@ import base64 from unittest.mock import MagicMock -import pytest from cryptojwt.jws.exception import NoSuitableSigningKeys from cryptojwt.jwt import JWT from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import build_keyjar from cryptojwt.utils import as_bytes from cryptojwt.utils import as_unicode +import pytest from idpyoidc.defaults import JWT_BEARER from idpyoidc.server import Server @@ -292,7 +292,8 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == { + "token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} @@ -435,8 +436,6 @@ def test_verify_per_client(self): request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("registration"), ) @@ -452,15 +451,12 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("registration"), ) assert res == {"method": "public", "client_id": client_id} res = verify_client( - self.endpoint_context, request, endpoint=self.server.get_endpoint("token"), ) @@ -468,8 +464,6 @@ def test_verify_per_client_per_endpoint(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("token"), ) @@ -479,8 +473,6 @@ def test_verify_per_client_per_endpoint(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("token"), ) @@ -501,8 +493,6 @@ def test_verify_client_jws_authn_method(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} http_info = {"headers": {}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, endpoint=self.server.get_endpoint("token"), @@ -514,8 +504,6 @@ def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} self.context.registration_access_token["1234567890"] = client_id res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("userinfo"), @@ -526,8 +514,6 @@ def test_verify_client_bearer_body(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("token"), ) @@ -541,8 +527,6 @@ def test_verify_client_client_secret_basic(self): http_info = {"headers": {"authorization": authz_token}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("token"), @@ -558,8 +542,6 @@ def test_verify_client_bearer_header(self): http_info = {"headers": {"authorization": token}} request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, @@ -590,8 +572,6 @@ def test_verify_client_jws_authn_method(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("token"), ) @@ -602,8 +582,6 @@ def test_verify_client_bearer_body(self): request = {"access_token": "1234567890", "client_id": client_id} self.context.registration_access_token["1234567890"] = client_id res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, get_client_id_from_token=get_client_id_from_token, endpoint=self.server.get_endpoint("userinfo"), @@ -611,17 +589,6 @@ def test_verify_client_bearer_body(self): assert set(res.keys()) == {"token", "method", "client_id"} assert res["method"] == "bearer_body" - def test_verify_client_client_secret_post(self): - request = {"client_id": client_id, "client_secret": client_secret} - res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), - request=request, - endpoint=self.server.get_endpoint("token"), - ) - assert set(res.keys()) == {"method", "client_id"} - assert res["method"] == "client_secret_post" - def test_verify_client_client_secret_basic(self): _token = "{}:{}".format(client_id, client_secret) token = as_unicode(base64.b64encode(as_bytes(_token))) @@ -629,8 +596,6 @@ def test_verify_client_client_secret_basic(self): http_info = {"headers": {"authorization": authz_token}} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request={}, http_info=http_info, endpoint=self.server.get_endpoint("token"), @@ -646,8 +611,6 @@ def test_verify_client_bearer_header(self): http_info = {"headers": {"authorization": token}} request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, http_info=http_info, get_client_id_from_token=get_client_id_from_token, @@ -660,8 +623,6 @@ def test_verify_client_authorization_none(self): # This is when it's explicitly said that no client auth method is allowed request = {"client_id": client_id} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("authorization"), ) @@ -672,8 +633,6 @@ def test_verify_client_registration_public(self): # This is when no special auth method is configured request = {"redirect_uris": ["https://example.com/cb"], "client_id": "client_id"} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("registration"), ) @@ -683,8 +642,6 @@ def test_verify_client_registration_none(self): # This is when no special auth method is configured request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - self.endpoint_context, - keyjar=self.server.get_attribute('keyjar'), request=request, endpoint=self.server.get_endpoint("registration"), ) @@ -706,8 +663,6 @@ class Mock: request = {"redirect_uris": ["https://example.com/cb"]} res = verify_client( - server.endpoint_context, - keyjar=server.get_attribute('keyjar'), request=request, endpoint=server.get_endpoint("registration") ) diff --git a/tests/test_server_20f_userinfo.py b/tests/test_server_20f_userinfo.py index 161be87c..8a5a8a64 100644 --- a/tests/test_server_20f_userinfo.py +++ b/tests/test_server_20f_userinfo.py @@ -192,7 +192,7 @@ def create_endpoint_context(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = server.endpoint_context + self.endpoint_context = server.context # Just has to be there self.endpoint_context.cdb["client1"] = { "add_claims": { @@ -422,7 +422,7 @@ def conf(self): @pytest.fixture(autouse=True) def create_endpoint_context(self, conf): self.server = Server(conf) - self.endpoint_context = self.server.endpoint_context + self.endpoint_context = self.server.context self.endpoint_context.cdb["client1"] = { "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] } @@ -476,7 +476,7 @@ def test_collect_user_info_custom_scope(self): def test_collect_user_info_scope_mapping_per_client(self, conf): conf["scopes_to_claims"] = SCOPE2CLAIMS server = Server(conf) - endpoint_context = server.endpoint_context + endpoint_context = server.context self.session_manager = endpoint_context.session_manager claims_interface = endpoint_context.claims_interface endpoint_context.cdb["client1"] = { diff --git a/tests/test_server_24_oauth2_resource_indicators.py b/tests/test_server_24_oauth2_resource_indicators.py index 3993e2d8..57848cbc 100644 --- a/tests/test_server_24_oauth2_resource_indicators.py +++ b/tests/test_server_24_oauth2_resource_indicators.py @@ -416,21 +416,21 @@ def create_endpoint_ri_disabled(self): conf = RESOURCE_INDICATORS_DISABLED server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + endpoint_context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint_context = endpoint_context - self.endpoint = server.server_get("endpoint", "authorization") - self.token_endpoint = server.server_get("endpoint", "token") + self.endpoint = server.get_endpoint("authorization") + self.token_endpoint = server.get_endpoint("token") self.session_manager = endpoint_context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.server_get("endpoint_context").keyjar.add_symmetric( + self.endpoint.upstream_get("endpoint_context").keyjar.add_symmetric( "client_1", "hemligtkodord1234567890" ) @@ -439,21 +439,21 @@ def create_endpoint_ri_enabled(self): conf = RESOURCE_INDICATORS_ENABLED server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - endpoint_context = server.endpoint_context + endpoint_context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint_context = endpoint_context - self.endpoint = server.server_get("endpoint", "authorization") - self.token_endpoint = server.server_get("endpoint", "token") + self.endpoint = server.get_endpoint("authorization") + self.token_endpoint = server.get_endpoint("token") self.session_manager = endpoint_context.session_manager self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.server_get("endpoint_context").keyjar.add_symmetric( + self.endpoint.upstream_get("context").keyjar.add_symmetric( "client_1", "hemligtkodord1234567890" ) @@ -478,7 +478,7 @@ def _mint_code(self, grant, client_id): # Constructing an authorization code is now done _code = grant.mint_token( session_id=session_id, - endpoint_context=self.endpoint_context, + context=self.endpoint_context, token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, @@ -505,7 +505,7 @@ def test_authorization_code_req_no_resource(self, create_endpoint_ri_enabled): Test that appropriate error message is returned when resource indicators is enabled for the authorization endpoint and resource parameter is missing from request. """ - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") msg = self.endpoint._post_parse_request({}, "client_1", endpoint_context) assert "error" in msg @@ -525,7 +525,7 @@ def test_authorization_code_req_no_resource_indicators_disabled(self, create_end """ Test successful authorization request when resource indicators is disabled. """ - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") request = AUTH_REQ.copy() del request["resource"] @@ -536,7 +536,7 @@ def test_authorization_code_req(self, create_endpoint_ri_enabled): """ Test successful authorization request when resource indicators is enabled. """ - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") request = AUTH_REQ.copy() msg = self.endpoint._post_parse_request(request, "client_1", endpoint_context) @@ -547,7 +547,7 @@ def test_authorization_code_req_per_client(self, create_endpoint_ri_disabled): Test that appropriate error message is returned when resource indicators is enabled per client for the authorization endpoint and requested resource is not permitted for client. """ - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") endpoint_context.cdb["client_1"]["resource_indicators"] = { "authorization_code": { "policy": { @@ -572,7 +572,7 @@ def test_authorization_code_req_no_resource_client(self, create_endpoint_ri_enab """ request = AUTH_REQ.copy() client_id = request["client_id"] - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") self.endpoint.kwargs["resource_indicators"]["policy"]["kwargs"][ "resource_servers_per_client" ] = {"client_2": ["client_1"]} @@ -591,7 +591,7 @@ def test_authorization_code_req_invalid_resource_client(self, create_endpoint_ri request = AUTH_REQ.copy() request["resource"] = "client_2" client_id = request["client_id"] - endpoint_context = self.endpoint.server_get("endpoint_context") + endpoint_context = self.endpoint.upstream_get("context") msg = self.endpoint._post_parse_request(request, client_id, endpoint_context) @@ -603,7 +603,7 @@ def test_access_token_req(self, create_endpoint_ri_enabled): """ Test successful access_token request when resource indicators is enabled. """ - self.endpoint.server_get("endpoint_context").cdb["client_3"] = { + self.endpoint.upstream_get("context").cdb["client_3"] = { "client_id": "client_3", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", @@ -657,7 +657,7 @@ def test_create_authn_response(self, create_endpoint_ri_enabled): Test that the requested access_token has the correct scopes based on the allowed scopes of the requested resources """ - self.endpoint.server_get("endpoint_context").cdb["client_3"] = { + self.endpoint.upstream_get("context").cdb["client_3"] = { "client_id": "client_3", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 78edd336..ee37e852 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -838,17 +838,20 @@ def test_refresh_token_request_other_client(self): CONTEXT.cdb = { "client_1": {} } -CONTEXT.keyjar = KeyJar() -CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") -CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") +KEYJAR = KeyJar() +KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") +KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") -def server_get(what, *args): - if what == "endpoint_context": +def upstream_get(what, *args): + if what == "context": if not args: return CONTEXT + elif what == 'attribute': + if args[0] == 'keyjar': + return KEYJAR def test_def_jwttoken(): - _handler = handler.factory(server_get=server_get, **DEFAULT_TOKEN_HANDLER_ARGS) + _handler = handler.factory(upstream_get=upstream_get, **DEFAULT_TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] token_payload = { 'sub': 'subject_id', @@ -864,7 +867,7 @@ def test_def_jwttoken(): assert True def test_jwttoken(): - _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] token_payload = { 'sub': 'subject_id', @@ -890,7 +893,7 @@ class MyAccessToken(Message): } def test_jwttoken_2(): - _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] token_payload = { 'sub': 'subject_id', diff --git a/tests/test_server_26_oidc_userinfo_endpoint.py b/tests/test_server_26_oidc_userinfo_endpoint.py index 76477c6b..b349da37 100755 --- a/tests/test_server_26_oidc_userinfo_endpoint.py +++ b/tests/test_server_26_oidc_userinfo_endpoint.py @@ -206,8 +206,8 @@ def create_endpoint(self): } self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - self.endpoint_context = self.server.endpoint_context - self.endpoint_context.cdb["client_1"] = { + self.context = self.server.context + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -216,7 +216,7 @@ def create_endpoint(self): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] } self.endpoint = self.server.get_endpoint("userinfo") - self.session_manager = self.endpoint_context.session_manager + self.session_manager = self.context.session_manager self.user_id = "diana" def _create_session(self, auth_req, sub_type="public", sector_identifier="", authn_info=None): @@ -391,7 +391,7 @@ def test_scopes_to_claims(self): } def test_scopes_to_claims_per_client(self): - self.endpoint_context.cdb["client_1"]["scopes_to_claims"] = { + self.context.cdb["client_1"]["scopes_to_claims"] = { **SCOPE2CLAIMS, "research_and_scholarship_2": [ "name", @@ -403,8 +403,8 @@ def test_scopes_to_claims_per_client(self): "eduperson_scoped_affiliation", ], } - self.endpoint_context.cdb["client_1"]["allowed_scopes"] = list( - self.endpoint_context.cdb["client_1"]["scopes_to_claims"].keys() + self.context.cdb["client_1"]["allowed_scopes"] = list( + self.context.cdb["client_1"]["scopes_to_claims"].keys() ) + ["aba"] _auth_req = AUTH_REQ.copy() @@ -470,7 +470,7 @@ def test_allowed_scopes(self): } def test_allowed_scopes_per_client(self): - self.endpoint_context.cdb["client_1"]["scopes_to_claims"] = { + self.context.cdb["client_1"]["scopes_to_claims"] = { **SCOPE2CLAIMS, "research_and_scholarship_2": [ "name", @@ -482,7 +482,7 @@ def test_allowed_scopes_per_client(self): "eduperson_scoped_affiliation", ], } - self.endpoint_context.cdb["client_1"]["allowed_scopes"] = list(SCOPE2CLAIMS.keys()) + self.context.cdb["client_1"]["allowed_scopes"] = list(SCOPE2CLAIMS.keys()) _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship_2"] diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 87d84b92..c3559868 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -401,7 +401,7 @@ def test_token_exchange_scopes_per_client(self): only get it if the subject token has it in its scope set, if it is permitted by the policy and if it is present in the clients allowed scopes. """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -421,13 +421,13 @@ def test_token_exchange_scopes_per_client(self): }, } - self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["openid", "email", "profile", "offline_access"] + self.context.cdb["client_1"]["allowed_scopes"] = ["openid", "email", "profile", "offline_access"] areq = AUTH_REQ.copy() areq["scope"].append("profile") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) @@ -459,7 +459,7 @@ def test_token_exchange_unsupported_scopes_per_client(self): """ Test that unsupported clients are handled appropriatelly """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -484,7 +484,7 @@ def test_token_exchange_unsupported_scopes_per_client(self): areq["scope"].append("profile") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -513,7 +513,7 @@ def test_token_exchange_no_scopes_requested(self): """ Test that the correct scopes are returned when no scopes requested by the client """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -538,7 +538,7 @@ def test_token_exchange_no_scopes_requested(self): areq["scope"].append("profile") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -614,7 +614,7 @@ def test_token_exchange_fails_if_disabled(self): Test that token exchange fails if it's not included in Token's grant_types_supported (that are set in its helper attribute). """ - self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + self.context.cdb["client_1"]["grant_types_supported"] = [ 'authorization_code', 'implicit', 'urn:ietf:params:oauth:grant-type:jwt-bearer', @@ -1032,7 +1032,7 @@ def test_token_exchange_unsupported_scope_requested_1(self): Client1 has an access_token1 (with offline_access, openid and profile scope). Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -1056,10 +1056,10 @@ def test_token_exchange_unsupported_scope_requested_1(self): areq["scope"].append("profile") areq["scope"].append("offline_access") - self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["offline_access", "profile"] + self.context.cdb["client_1"]["allowed_scopes"] = ["offline_access", "profile"] session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -1120,7 +1120,7 @@ def test_token_exchange_unsupported_scope_requested_2(self): Client1 has an access_token1 (with openid and profile scope). Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -1139,14 +1139,14 @@ def test_token_exchange_unsupported_scope_requested_2(self): } }, } - self.endpoint_context.cdb["client_1"]["allowed_scopes"] = ["openid", "profile"] + self.context.cdb["client_1"]["allowed_scopes"] = ["openid", "profile"] areq = AUTH_REQ.copy() areq["scope"].append("profile") areq["scope"].append("offline_access") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -1211,7 +1211,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): Client1 has an access_token1 (with openid and profile scope). Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -1230,7 +1230,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): } }, } - self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + self.context.cdb["client_1"]["grant_types_supported"] = [ 'authorization_code', 'implicit', 'urn:ietf:params:oauth:grant-type:jwt-bearer', @@ -1242,7 +1242,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): areq["scope"].append("offline_access") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -1283,7 +1283,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): _resp = self.endpoint.process_request(request=_req) assert _resp["response_args"]["scope"] == ["offline_access"] - _c_interface = self.introspection_endpoint.server_get("endpoint_context").claims_interface + _c_interface = self.introspection_endpoint.upstream_get("context").claims_interface grant.claims = { "introspection": _c_interface.get_claims( session_id, scopes=AUTH_REQ["scope"], claims_release_point="introspection" @@ -1293,7 +1293,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): { "token": _resp["response_args"]["access_token"], "client_id": "client_1", - "client_secret": self.endpoint_context.cdb["client_1"]["client_secret"], + "client_secret": self.context.cdb["client_1"]["client_secret"], } ) _resp_intro = self.introspection_endpoint.process_request(_req) @@ -1319,7 +1319,7 @@ def test_token_exchange_unsupported_scope_requested_4(self): Client1 has an access_token1 (with openid and profile scope). Then, client1 exchanges access_token1 for a new refresh token """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -1338,7 +1338,7 @@ def test_token_exchange_unsupported_scope_requested_4(self): } }, } - self.endpoint_context.cdb["client_1"]["grant_types_supported"] = [ + self.context.cdb["client_1"]["grant_types_supported"] = [ 'authorization_code', 'implicit', 'urn:ietf:params:oauth:grant-type:jwt-bearer', @@ -1350,7 +1350,7 @@ def test_token_exchange_unsupported_scope_requested_4(self): areq["scope"].append("offline_access") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() @@ -1417,7 +1417,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): Client1 has an access_token1 (with openid and profile scope). Then, client1 exchanges access_token1 for a new refresh token """ - self.endpoint_context.cdb["client_1"]["token_exchange"] = { + self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ "urn:ietf:params:oauth:token-type:access_token", "urn:ietf:params:oauth:token-type:refresh_token", @@ -1442,7 +1442,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): areq["scope"].append("offline_access") session_id = self._create_session(areq) - grant = self.endpoint_context.authz(session_id, areq) + grant = self.context.authz(session_id, areq) code = self._mint_code(grant, areq["client_id"]) _token_request = TOKEN_REQ_DICT.copy() From d45ab4fe7fa3b0f6522462100e8c060f671aa69a Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 10 Feb 2023 10:08:30 +0100 Subject: [PATCH 56/76] Fixed some remaining flake8 complains. Added TokenExchangeHelper which had somehow managed to go missing before. --- src/idpyoidc/client/entity.py | 1 - src/idpyoidc/defaults.py | 1 - src/idpyoidc/server/authz/__init__.py | 18 +++++++++++++----- src/idpyoidc/server/oauth2/token.py | 4 +++- src/idpyoidc/server/oauth2/token_helper.py | 3 +++ src/idpyoidc/server/oidc/token.py | 2 ++ src/idpyoidc/server/oidc/token_helper.py | 2 ++ src/idpyoidc/server/util.py | 1 - tests/test_server_36_oauth2_token_exchange.py | 10 ++++------ 9 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index d00212dd..e362eccb 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -201,4 +201,3 @@ def import_keys(self, keyspec): def get_callback_uris(self): return self.context.claims.callback_uri - diff --git a/src/idpyoidc/defaults.py b/src/idpyoidc/defaults.py index 83341e84..5e8e65c3 100644 --- a/src/idpyoidc/defaults.py +++ b/src/idpyoidc/defaults.py @@ -12,4 +12,3 @@ JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" BASECHR = string.ascii_letters + string.digits - diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index b90e5ce3..f90094f6 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -61,10 +61,12 @@ def __call__( request: Union[dict, Message], resources: Optional[list] = None, ) -> Grant: - session_info = self.upstream_get("context").session_manager.get_session_info( + _context = self.upstream_get("context") + session_info = _context.session_manager.get_session_info( session_id=session_id, grant=True ) grant = session_info["grant"] + _client_id = session_info['client_id'] args = self.grant_config.copy() @@ -72,20 +74,26 @@ def __call__( if key == "expires_in": grant.set_expires_at(val) elif key == "usage_rules": - setattr(grant, key, self.usage_rules(request.get("client_id"))) + setattr(grant, key, self.usage_rules(_client_id)) else: setattr(grant, key, val) if resources is None: - grant.resources = [session_info["client_id"]] + grant.resources = [_client_id] else: grant.resources = resources - # After this is where user consent should be handled + # Scope handling. If allowed scopes are defined for the client filter using that scopes = grant.scope if not scopes: scopes = request.get("scope", []) - grant.scope = scopes + else: + _allowed = _context.cdb[_client_id].get('allowed_scopes', []) + if _allowed: + scopes = list(set(scopes).intersection(set(_allowed))) + grant.scope = scopes + + # After this is where user consent should be handled grant.claims = self.upstream_get("context").claims_interface.get_claims_all_usage( session_id=session_id, scopes=scopes ) diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index e20dc4ac..987bd39d 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -33,7 +33,9 @@ class Token(Endpoint): helper_by_grant_type = { "authorization_code": AccessTokenHelper, "refresh_token": RefreshTokenHelper, + "urn:ietf:params:oauth:grant-type:token-exchange": TokenExchangeHelper, } + token_exchange_helper = TokenExchangeHelper def __init__(self, upstream_get, new_refresh_token=False, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) @@ -132,7 +134,7 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar _access_token = response_args["access_token"] _context = self.upstream_get("context") - if isinstance(_helper, TokenExchangeHelper): + if isinstance(_helper, self.token_exchange_helper): _handler_key = _helper.get_handler_key(request, _context) else: _handler_key = "access_token" diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py index a7edbcd3..d6c3a2c1 100755 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ b/src/idpyoidc/server/oauth2/token_helper.py @@ -33,6 +33,7 @@ class TokenEndpointHelper(object): + def __init__(self, endpoint, config=None): self.endpoint = endpoint self.config = config @@ -154,6 +155,7 @@ def validate_resource_indicators_policy(request, context, **kwargs): class AccessTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): """ @@ -341,6 +343,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index b4280901..045ce73b 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -42,3 +42,5 @@ class Token(token.Token): "urn:openid:params:grant-type:ciba": CIBATokenHelper, "urn:ietf:params:oauth:grant-type:token-exchange": TokenExchangeHelper, } + + token_exchange_helper = TokenExchangeHelper diff --git a/src/idpyoidc/server/oidc/token_helper.py b/src/idpyoidc/server/oidc/token_helper.py index 5b8e73f6..80972205 100755 --- a/src/idpyoidc/server/oidc/token_helper.py +++ b/src/idpyoidc/server/oidc/token_helper.py @@ -21,6 +21,7 @@ class AccessTokenHelper(TokenEndpointHelper): + def _get_session_info(self, request, session_manager): if request["grant_type"] != "authorization_code": return self.error_cls(error="invalid_request", error_description="Unknown grant_type") @@ -208,6 +209,7 @@ def post_parse_request( class RefreshTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index ef481578..4ec0eaa9 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -169,4 +169,3 @@ def execute(spec, **kwargs): return _func(**kwargs) else: return kwargs - diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index c3559868..de9eabc0 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -1117,8 +1117,9 @@ def test_token_exchange_unsupported_scope_requested_2(self): - allowed_scopes: [profile] - requested_token_type: "...:access_token" Scenario: - Client1 has an access_token1 (with openid and profile scope). - Then, client1 exchanges access_token1 for a new access_token1_13 with scope offline_access + Client1 has an access_token1 (with scopes openid and profile). + Then, client1 wants to exchange access_token1 for a new access_token1_13 with scope + offline_access. This is not allowed. """ self.context.cdb["client_1"]["token_exchange"] = { "subject_token_types_supported": [ @@ -1187,10 +1188,7 @@ def test_token_exchange_unsupported_scope_requested_2(self): ) _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_scope" - assert ( - _resp["error_description"] - == "Invalid requested scopes" - ) + assert _resp["error_description"] == "Invalid requested scopes" token_exchange_req["scope"] = "offline_access profile" From 58852fb3d94cdfc4a95058389263242b87cdca6d Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 10 Feb 2023 10:42:36 +0100 Subject: [PATCH 57/76] Added Kristos's Token Revocation. --- src/idpyoidc/message/oauth2/__init__.py | 28 +- .../server/oauth2/token_revocation.py | 134 ++++ ...st_server_38_oauth2_revocation_endpoint.py | 580 ++++++++++++++++++ 3 files changed, 741 insertions(+), 1 deletion(-) create mode 100644 src/idpyoidc/server/oauth2/token_revocation.py create mode 100644 tests/test_server_38_oauth2_revocation_endpoint.py diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 11852a2d..b799152c 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -10,6 +10,7 @@ from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import VerificationError from idpyoidc.message import Message +from idpyoidc.message import msg_ser from idpyoidc.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from idpyoidc.message import OPTIONAL_LIST_OF_STRINGS from idpyoidc.message import REQUIRED_LIST_OF_SP_SEP_STRINGS @@ -20,7 +21,6 @@ from idpyoidc.message import SINGLE_REQUIRED_BOOLEAN from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import msg_ser logger = logging.getLogger(__name__) @@ -581,6 +581,32 @@ class JSONWebToken(Message): 'entitlements': OPTIONAL_LIST_OF_STRINGS } +# RFC 7009 +class TokenRevocationRequest(Message): + c_param = { + "token": SINGLE_REQUIRED_STRING, + "token_type_hint": SINGLE_OPTIONAL_STRING, + # The ones below are part of authentication information + "client_id": SINGLE_OPTIONAL_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + } + + +class TokenRevocationResponse(Message): + pass + + +class TokenRevocationErrorResponse(ResponseMessage): + """ + Error response from the revocation endpoint + """ + c_allowed_values = ResponseMessage.c_allowed_values.copy() + c_allowed_values.update({ + "error": [ + "unsupported_token_type" + ] + }) + def factory(msgtype, **kwargs): """ diff --git a/src/idpyoidc/server/oauth2/token_revocation.py b/src/idpyoidc/server/oauth2/token_revocation.py new file mode 100644 index 00000000..1c8c1d54 --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_revocation.py @@ -0,0 +1,134 @@ +"""Implements RFC7009""" + +import logging + +from idpyoidc.exception import ImproperlyConfigured +from idpyoidc.message import oauth2 +from idpyoidc.server.endpoint import Endpoint +from idpyoidc.server.token.exception import UnknownToken +from idpyoidc.server.token.exception import WrongTokenClass +from idpyoidc.util import importer + +logger = logging.getLogger(__name__) + + +class TokenRevocation(Endpoint): + """Implements RFC7009""" + + request_cls = oauth2.TokenRevocationRequest + response_cls = oauth2.TokenRevocationResponse + error_cls = oauth2.TokenRevocationErrorResponse + request_format = "urlencoded" + response_format = "json" + endpoint_name = "revocation_endpoint" + name = "token_revocation" + default_capabilities = { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "bearer_header", + "private_key_jwt", + ] + } + + token_types_supported = ["authorization_code", "access_token", "refresh_token"] + + def __init__(self, upstream_get, **kwargs): + Endpoint.__init__(self, upstream_get, **kwargs) + self.token_revocation_kwargs = kwargs + + def get_client_id_from_token(self, endpoint_context, token, request=None): + _info = endpoint_context.session_manager.get_session_info_by_token( + token, handler_key="access_token" + ) + return _info["client_id"] + + def process_request(self, request=None, **kwargs): + """ + :param request: The revocation request as a dictionary + :param kwargs: + :return: + """ + _revoke_request = self.request_cls(**request) + if "error" in _revoke_request: + return _revoke_request + + request_token = _revoke_request["token"] + _resp = self.response_cls() + _context = self.upstream_get("endpoint_context") + logger.debug("Token Revocation") + + try: + _session_info = _context.session_manager.get_session_info_by_token( + request_token, grant=True + ) + except (UnknownToken, WrongTokenClass): + return {"response_args": _resp} + + client_id = _session_info["client_id"] + if client_id != _revoke_request["client_id"]: + logger.debug("{} owner of token".format(client_id)) + logger.warning("Client using token it was not given") + return self.error_cls(error="invalid_grant", error_description="Wrong client") + + grant = _session_info["grant"] + _token = grant.get_token(request_token) + + try: + self.token_types_supported = _context.cdb[client_id]["token_revocation"][ + "token_types_supported"] + except: + self.token_types_supported = self.token_revocation_kwargs.get("token_types_supported", + self.token_types_supported) + + try: + self.policy = _context.cdb[client_id]["token_revocation"]["policy"] + except: + self.policy = self.token_revocation_kwargs.get("policy", { + "": {"callable": validate_token_revocation_policy}}) + + if _token.token_class not in self.token_types_supported: + desc = ( + "The authorization server does not support the revocation of " + "the presented token type. That is, the client tried to revoke an access " + "token on a server not supporting this feature." + ) + return self.error_cls(error="unsupported_token_type", error_description=desc) + + return self._revoke(_revoke_request, _session_info) + + def _revoke(self, request, session_info): + _context = self.upstream_get("endpoint_context") + _mngr = _context.session_manager + _token = _mngr.find_token(session_info["branch_id"], request["token"]) + + _cls = _token.token_class + if _cls not in self.policy: + _cls = "" + + temp_policy = self.policy[_cls] + callable = temp_policy["callable"] + kwargs = temp_policy.get("kwargs", {}) + + if isinstance(callable, str): + try: + fn = importer(callable) + except Exception: + raise ImproperlyConfigured(f"Error importing {callable} policy callable") + else: + fn = callable + + try: + return fn(_token, session_info=session_info, **kwargs) + except Exception as e: + logger.error(f"Error while executing the {fn} policy callable: {e}") + return self.error_cls(error="server_error", error_description="Internal server error") + + +def validate_token_revocation_policy(token, session_info, **kwargs): + _token = token + _token.revoke() + + response_args = {"response_args": {}} + return oauth2.TokenRevocationResponse(**response_args) diff --git a/tests/test_server_38_oauth2_revocation_endpoint.py b/tests/test_server_38_oauth2_revocation_endpoint.py new file mode 100644 index 00000000..3860c760 --- /dev/null +++ b/tests/test_server_38_oauth2_revocation_endpoint.py @@ -0,0 +1,580 @@ +import base64 +import os + +import pytest +from cryptojwt import as_unicode +from cryptojwt.utils import as_bytes + +from idpyoidc.message.oauth2 import TokenRevocationRequest +from idpyoidc.message.oauth2 import TokenRevocationResponse +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.server import Server +from idpyoidc.server.authn_event import create_authn_event +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.exception import ClientAuthenticationError +from idpyoidc.server.oauth2.authorization import Authorization +from idpyoidc.server.oauth2.introspection import Introspection +from idpyoidc.server.oauth2.token_revocation import TokenRevocation +from idpyoidc.server.oauth2.token_revocation import validate_token_revocation_policy +from idpyoidc.server.oidc.token import Token +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.time_util import utc_time_sans_frac +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "offline_access"], + state="STATE", + response_type="code id_token", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +@pytest.mark.parametrize("jwt_token", [True, False]) +class TestEndpoint: + + @pytest.fixture(autouse=True) + def create_endpoint(self, jwt_token): + conf = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + }, + }, + "endpoint": { + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "introspection": { + "path": "{}/intro", + "class": Introspection, + "kwargs": { + "client_authn_method": ["client_secret_post"], + "enable_claims_per_client": False, + }, + }, + "token_revocation": { + "path": "{}/revoke", + "class": TokenRevocation, + "kwargs": { + "client_authn_method": ["client_secret_post"], + }, + }, + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": { + "path": "{}/userinfo", + "class": UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + "client_authn": verify_client, + "template_dir": "template", + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": [ + "access_token", + "refresh_token", + "id_token", + ], + "max_usage": 1, + }, + "access_token": {}, + "refresh_token": { + "supports_minting": ["access_token", "refresh_token"], + }, + }, + "expires_in": 43200, + } + }, + }, + "session_params": SESSION_PARAMS, + } + if jwt_token: + conf["token_handler_args"]["token"] = { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": {}, + } + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.context + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "add_claims": { + "always": { + "introspection": ["nickname", "eduperson_scoped_affiliation"], + }, + "by_scope": {}, + }, + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", + "research_and_scholarship"] + } + endpoint_context.keyjar.import_jwks_as_json( + endpoint_context.keyjar.export_jwks_as_json(private=True), + endpoint_context.issuer, + ) + self.revocation_endpoint = server.get_endpoint("token_revocation") + self.token_endpoint = server.get_endpoint("token") + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session( + ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type + ) + + def _mint_token(self, token_class, grant, session_id, based_on=None, **kwargs): + # Constructing an authorization code is now done + return grant.mint_token( + session_id=session_id, + context=self.token_endpoint.upstream_get("context"), + token_class=token_class, + token_handler=self.session_manager.token_handler.handler[token_class], + expires_at=utc_time_sans_frac() + 300, # 5 minutes from now + based_on=based_on, + **kwargs + ) + + def _get_access_token(self, areq): + session_id = self._create_session(areq) + # Consent handling + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, areq) + self.session_manager[session_id] = grant + # grant = self.session_manager[session_id] + code = self._mint_token("authorization_code", grant, session_id) + return self._mint_token("access_token", grant, session_id, code) + + def _get_refresh_token(self, areq): + session_id = self._create_session(areq) + # Consent handling + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, areq) + self.session_manager[session_id] = grant + # grant = self.session_manager[session_id] + code = self._mint_token("authorization_code", grant, session_id) + return self._mint_token("refresh_token", grant, session_id, code) + + def test_parse_no_authn(self): + access_token = self._get_access_token(AUTH_REQ) + with pytest.raises(ClientAuthenticationError): + self.revocation_endpoint.parse_request({"token": access_token.value}) + + def test_parse_with_client_auth_in_req(self): + access_token = self._get_access_token(AUTH_REQ) + + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + + assert isinstance(_req, TokenRevocationRequest) + assert set(_req.keys()) == {"token", "client_id", "client_secret"} + + def test_parse_with_wrong_client_authn(self): + access_token = self._get_access_token(AUTH_REQ) + + _basic_token = "{}:{}".format( + "client_1", + self.revocation_endpoint.upstream_get("endpoint_context").cdb["client_1"][ + "client_secret" + ], + ) + _basic_token = as_unicode(base64.b64encode(as_bytes(_basic_token))) + _basic_authz = "Basic {}".format(_basic_token) + http_info = {"headers": {"authorization": _basic_authz}} + + with pytest.raises(ClientAuthenticationError): + self.revocation_endpoint.parse_request( + {"token": access_token.value}, http_info=http_info + ) + + def test_process_request(self): + access_token = self._get_access_token(AUTH_REQ) + + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": self.revocation_endpoint.upstream_get("endpoint_context").cdb[ + "client_1" + ]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert _resp + assert set(_resp.keys()) == {"response_args"} + + def test_do_response(self): + access_token = self._get_access_token(AUTH_REQ) + + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": self.revocation_endpoint.upstream_get("endpoint_context").cdb[ + "client_1" + ]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + msg_info = self.revocation_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg_info, dict) + assert set(msg_info.keys()) == {"response", "http_headers"} + assert msg_info["http_headers"] == [ + ("Content-type", "application/json; charset=utf-8"), + ("Pragma", "no-cache"), + ("Cache-Control", "no-store"), + ] + + def test_do_response_no_token(self): + # access_token = self._get_access_token(AUTH_REQ) + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _req = self.revocation_endpoint.parse_request( + { + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "error" in _resp + + def test_access_token(self): + access_token = self._get_access_token(AUTH_REQ) + assert access_token.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + assert access_token.revoked + + def test_access_token_per_client(self): + + def custom_token_revocation_policy(token, session_info, **kwargs): + _token = token + _token.revoke() + response_args = {"response_args": {"type": "custom"}} + return TokenRevocationResponse(**response_args) + + access_token = self._get_access_token(AUTH_REQ) + assert access_token.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _context.cdb["client_1"]["token_revocation"] = { + "token_types_supported": [ + "access_token", + ], + "policy": { + "": { + "callable": validate_token_revocation_policy, + }, + "access_token": { + "callable": custom_token_revocation_policy, + } + }, + } + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + assert "type" in _resp["response_args"] + assert _resp["response_args"]["type"] == "custom" + assert access_token.revoked + + def test_missing_token_policy_per_client(self): + + def custom_token_revocation_policy(token, session_info, **kwargs): + _token = token + _token.revoke() + response_args = {"response_args": {"type": "custom"}} + return TokenRevocationResponse(**response_args) + + access_token = self._get_access_token(AUTH_REQ) + assert access_token.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _context.cdb["client_1"]["token_revocation"] = { + "token_types_supported": [ + "access_token", + ], + "policy": { + "": { + "callable": validate_token_revocation_policy, + }, + "refresh_token": { + "callable": custom_token_revocation_policy, + } + }, + } + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + assert access_token.revoked + + def test_code(self): + session_id = self._create_session(AUTH_REQ) + + # Apply consent + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + self.session_manager[session_id] = grant + + code = self._mint_token("authorization_code", grant, session_id) + assert code.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + + _req = self.revocation_endpoint.parse_request( + { + "token": code.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + assert code.revoked + + def test_refresh_token(self): + refresh_token = self._get_refresh_token(AUTH_REQ) + assert refresh_token.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _req = self.revocation_endpoint.parse_request( + { + "token": refresh_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + assert refresh_token.revoked + + def test_expired_access_token(self): + access_token = self._get_access_token(AUTH_REQ) + access_token.expires_at = utc_time_sans_frac() - 1000 + + _context = self.revocation_endpoint.upstream_get("endpoint_context") + + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + + def test_revoked_access_token(self): + access_token = self._get_access_token(AUTH_REQ) + access_token.revoked = True + + _context = self.revocation_endpoint.upstream_get("endpoint_context") + + _req = self.revocation_endpoint.parse_request( + { + "token": access_token.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + assert "response_args" in _resp + + def test_unsupported_token_type(self): + self.revocation_endpoint.token_types_supported = ["access_token"] + session_id = self._create_session(AUTH_REQ) + + # Apply consent + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + self.session_manager[session_id] = grant + + code = self._mint_token("authorization_code", grant, session_id) + assert code.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + + _req = self.revocation_endpoint.parse_request( + { + "token": code.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + err_dscr = ( + "The authorization server does not support the revocation of " + "the presented token type. That is, the client tried to revoke an access " + "token on a server not supporting this feature." + ) + assert "error" in _resp + assert _resp.to_dict() == { + "error": "unsupported_token_type", + "error_description": err_dscr, + } + assert code.revoked is False + + def test_unsupported_token_type_per_client(self): + _context = self.revocation_endpoint.upstream_get("endpoint_context") + _context.cdb["client_1"]["token_revocation"] = { + "token_types_supported": [ + "refresh_token", + ], + } + session_id = self._create_session(AUTH_REQ) + + # Apply consent + grant = self.token_endpoint.upstream_get("endpoint_context").authz(session_id, AUTH_REQ) + self.session_manager[session_id] = grant + + code = self._mint_token("authorization_code", grant, session_id) + assert code.revoked is False + _context = self.revocation_endpoint.upstream_get("endpoint_context") + + _req = self.revocation_endpoint.parse_request( + { + "token": code.value, + "client_id": "client_1", + "client_secret": _context.cdb["client_1"]["client_secret"], + } + ) + _resp = self.revocation_endpoint.process_request(_req) + err_dscr = ( + "The authorization server does not support the revocation of " + "the presented token type. That is, the client tried to revoke an access " + "token on a server not supporting this feature." + ) + assert "error" in _resp + assert _resp.to_dict() == { + "error": "unsupported_token_type", + "error_description": err_dscr, + } + assert code.revoked is False From 6cf0d454172974d37b347ed3aa8840d03f3e7f30 Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 10 Feb 2023 10:43:57 +0100 Subject: [PATCH 58/76] Added Kristos's Token Revocation. --- src/idpyoidc/message/oauth2/__init__.py | 1 + src/idpyoidc/server/oauth2/token_revocation.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index b799152c..717c55e6 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -581,6 +581,7 @@ class JSONWebToken(Message): 'entitlements': OPTIONAL_LIST_OF_STRINGS } + # RFC 7009 class TokenRevocationRequest(Message): c_param = { diff --git a/src/idpyoidc/server/oauth2/token_revocation.py b/src/idpyoidc/server/oauth2/token_revocation.py index 1c8c1d54..8ac45f49 100644 --- a/src/idpyoidc/server/oauth2/token_revocation.py +++ b/src/idpyoidc/server/oauth2/token_revocation.py @@ -78,13 +78,13 @@ def process_request(self, request=None, **kwargs): try: self.token_types_supported = _context.cdb[client_id]["token_revocation"][ "token_types_supported"] - except: + except Exception: self.token_types_supported = self.token_revocation_kwargs.get("token_types_supported", self.token_types_supported) try: self.policy = _context.cdb[client_id]["token_revocation"]["policy"] - except: + except Exception: self.policy = self.token_revocation_kwargs.get("policy", { "": {"callable": validate_token_revocation_policy}}) From 4ae276269db7cccc502b24f56eb5a93f0db16278 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 13 Feb 2023 15:36:12 +0100 Subject: [PATCH 59/76] Bug discovered while attempting to use the package by another package. --- src/idpyoidc/claims.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index 22198795..6e8d79c1 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -41,7 +41,7 @@ def __init__(self, ImpExp.__init__(self) if isinstance(prefer, dict): - self.prefer = {k: v for k, v in prefer.items() if k in self.supports} + self.prefer = {k: v for k, v in prefer.items() if k in self.supports()} else: self.prefer = {} @@ -76,7 +76,7 @@ def _callback_uris(self, base_url, hex): elif type in ["id_token", "id_token token"]: _uri.append('implicit') - if "form_post" in self.supports: + if "form_post" in self._supports: _uri.append("form_post") callback_uri = {} @@ -167,7 +167,10 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} - def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None): + def load_conf(self, + configuration: dict, + supports: dict, + keyjar: Optional[KeyJar] = None) -> KeyJar: for attr, val in configuration.items(): if attr == "preference": for k, v in val.items(): From 3c322c4f884a6d447c7c79fc00a0dd9d3961ed56 Mon Sep 17 00:00:00 2001 From: roland Date: Mon, 13 Feb 2023 15:44:39 +0100 Subject: [PATCH 60/76] Bug discovered when attempting to use this package by another package. --- src/idpyoidc/claims.py | 2 +- src/idpyoidc/server/endpoint_context.py | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index 6e8d79c1..fa28bb17 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -41,7 +41,7 @@ def __init__(self, ImpExp.__init__(self) if isinstance(prefer, dict): - self.prefer = {k: v for k, v in prefer.items() if k in self.supports()} + self.prefer = {k: v for k, v in prefer.items() if k in self._supports} else: self.prefer = {} diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 190771fd..742084c2 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -30,22 +30,6 @@ logger = logging.getLogger(__name__) -def get_provider_capabilities(conf, endpoints): - _cap = conf.get("capabilities", {}) - if _cap is None: - _cap = {} - - for endpoint, endpoint_instance in endpoints.items(): - if endpoint in ["webfinger", "provider_config"]: - continue - - for key, val in endpoint_instance.get_provider_info_attributes().items(): - if key not in _cap: - _cap[key] = val - - return _cap - - def init_user_info(conf, cwd: str): kwargs = conf.get("kwargs", {}) From 3269e08d8ef12ccf560d15db43c979ca67de0ea6 Mon Sep 17 00:00:00 2001 From: roland Date: Mon, 13 Feb 2023 16:57:14 +0100 Subject: [PATCH 61/76] Encrypting request parameter turned off by default but if turned on should actually work. --- src/idpyoidc/client/oauth2/authorization.py | 8 +++++--- src/idpyoidc/client/oidc/authorization.py | 5 ++++- src/idpyoidc/client/oidc/utils.py | 7 ++++--- src/idpyoidc/server/oidc/authorization.py | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index ff308d3e..39f5ff7d 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -32,9 +32,11 @@ class Authorization(Service): _supports = { "response_types_supported": ["code", 'token'], "response_modes_supported": ['query', 'fragment'], - "request_object_signing_alg_values_supported": claims.get_signing_algs, - "request_object_encryption_alg_values_supported": claims.get_encryption_algs, - "request_object_encryption_enc_values_supported": claims.get_encryption_encs, + # Below not OAuth2 functionality + # "request_object_signing_alg_values_supported": claims.get_signing_algs, + # "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + # "request_object_encryption_enc_values_supported": claims.get_encryption_encs, + # "encrypt_request_object_supported": False, } _callback_path = { diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 204a8cd9..26ab2047 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -257,6 +257,9 @@ def construct_request_parameter( _req_jwt = make_openid_request(req, **_mor_args) + if 'target' not in kwargs: + kwargs['target'] = _context.provider_info["issuer"] + # Should the request be encrypted _req_jwte = request_object_encryption(_req_jwt, _context, self.upstream_get('attribute', 'keyjar'), @@ -300,7 +303,7 @@ def oidc_post_construct(self, req, **kwargs): _req = self.construct_request_parameter(req, _request_param, **kwargs) req["request_uri"] = self.store_request_on_file(_req, **kwargs) elif _request_param == "request": - _req = self.construct_request_parameter(req, _request_param) + _req = self.construct_request_parameter(req, _request_param, **kwargs) req["request"] = _req if _req: diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index ced4d6f4..4ccd9f1c 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -46,14 +46,15 @@ def request_object_encryption(msg, service_context, keyjar, **kwargs): except KeyError: _kid = "" - if "target" not in kwargs: + _target = kwargs.get('target', kwargs.get('recv', None)) + if _target is None: raise MissingRequiredAttribute("No target specified") if _kid: - _keys = keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"], kid=_kid) + _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target, kid=_kid) _jwe["kid"] = _kid else: - _keys = keyjar.get_encrypt_key(_kty, issuer_id=kwargs["target"]) + _keys = keyjar.get_encrypt_key(_kty, issuer_id=_target) return _jwe.encrypt(_keys) diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index 3182cb22..eb6d06b7 100755 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -78,7 +78,7 @@ class Authorization(authorization.Authorization): _supports = { "claims_parameter_supported": True, - "encrypt_request_object_supported": True, + "encrypt_request_object_supported": False, "request_object_signing_alg_values_supported": claims.get_signing_algs, "request_object_encryption_alg_values_supported": claims.get_encryption_algs, "request_object_encryption_enc_values_supported": claims.get_encryption_encs, From 04ff81e60f50ed81813371079f451b476d32406c Mon Sep 17 00:00:00 2001 From: roland Date: Mon, 13 Feb 2023 17:05:50 +0100 Subject: [PATCH 62/76] Oops, this should work. --- src/idpyoidc/client/oidc/authorization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 26ab2047..44a7ada9 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -258,7 +258,7 @@ def construct_request_parameter( _req_jwt = make_openid_request(req, **_mor_args) if 'target' not in kwargs: - kwargs['target'] = _context.provider_info["issuer"] + kwargs['target'] = _context.provider_info.get("issuer", _context.issuer) # Should the request be encrypted _req_jwte = request_object_encryption(_req_jwt, _context, From d2487f647fdb80ae8863483d8e028662417471c8 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 23 Feb 2023 08:26:51 +0100 Subject: [PATCH 63/76] Being more explicit on what type of client it is. Deal with policy expressed in two claims: metadata and metadata_policy. --- src/idpyoidc/client/claims/transform.py | 32 +++++++++++++++++-- src/idpyoidc/client/entity.py | 8 +++-- src/idpyoidc/client/oauth2/__init__.py | 3 +- src/idpyoidc/client/oauth2/server_metadata.py | 2 ++ src/idpyoidc/client/oidc/__init__.py | 6 +++- src/idpyoidc/client/oidc/access_token.py | 2 +- src/idpyoidc/client/provider/github.py | 2 +- src/idpyoidc/client/provider/linkedin.py | 2 +- src/idpyoidc/client/service.py | 6 ++-- src/idpyoidc/server/client_authn.py | 2 +- tests/private/token_jwks.json | 2 +- tests/test_client_27_conversation.py | 2 +- 12 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/idpyoidc/client/claims/transform.py b/src/idpyoidc/client/claims/transform.py index 87e4e2d0..2459774c 100644 --- a/src/idpyoidc/client/claims/transform.py +++ b/src/idpyoidc/client/claims/transform.py @@ -112,14 +112,31 @@ def array_or_singleton(claim_spec, values): def _is_subset(a, b): + # Is 'a' a subset of 'b' if isinstance(a, list): if isinstance(b, list): - return set(b).issubset(set(a)) + return set(a).issubset(set(b)) elif isinstance(b, list): return a in b else: return a == b +def _intersection(a, b): + res = None + if isinstance(a, list): + if isinstance(b, list): + res = list(set(a).intersection(set(b))) + else: + if b in a: + res = b + else: + res = [] + elif isinstance(b, list): + if a in b: + res = [a] + else: + res = [] + return res def preferred_to_registered(prefers: dict, supported: dict, registration_response: Optional[dict] = None): @@ -136,10 +153,19 @@ def preferred_to_registered(prefers: dict, supported: dict, if registration_response: for key, val in registration_response.items(): if key in REGISTER2PREFERRED: - if _is_subset(val, supported.get(REGISTER2PREFERRED[key])): + # Is the response value with in what this instance supports + _supports = supported.get(REGISTER2PREFERRED[key]) + if _is_subset(val, _supports): registered[key] = val else: - logger.warning(f'OP tells me to do something I do not support: {key} = {val}') + logger.warning( + f'OP tells me to do something I do not support: (key) = {val} not within ' + f'{_supports}') + _val = _intersection(val, _supports) + if _val: + registered[key] = _val + else: + raise ValueError('Not able to support the OPs choice') else: registered[key] = val # Should I just accept with the OP says ?? diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index e362eccb..9e8b7a8e 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -13,6 +13,7 @@ from idpyoidc.client.client_auth import method_to_item from idpyoidc.client.configure import Configuration from idpyoidc.client.defaults import DEFAULT_OAUTH2_SERVICES +from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.service import init_services from idpyoidc.client.service_context import ServiceContext from idpyoidc.context import OidcContext @@ -75,7 +76,7 @@ def redirect_uris_from_callback_uris(callback_uris): return res -class Entity(Unit): # This is a Client +class Entity(Unit): # This is a Client. What type is undefined here. parameter = { 'entity_id': None, 'jwks_uri': None, @@ -117,7 +118,10 @@ def __init__( _srvs = None if not _srvs: - _srvs = DEFAULT_OAUTH2_SERVICES + if client_type == 'oauth2': + _srvs = DEFAULT_OAUTH2_SERVICES + else: + _srvs = DEFAULT_OIDC_SERVICES self._service = init_services(service_definitions=_srvs, upstream_get=self.unit_get) diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 4800a7b9..ab15c941 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -37,6 +37,7 @@ class ExpiredToken(Exception): class Client(Entity): + client_type = 'oauth2' def __init__( self, keyjar: Optional[KeyJar] = None, @@ -69,7 +70,7 @@ def __init__( """ if not client_type: - client_type = "oauth2" + client_type = self.client_type if verify_ssl is False: # just ignore verify_ssl until it goes away diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index 185da6e2..9bc868f4 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -118,6 +118,8 @@ def _update_service_context(self, resp): # that. Otherwise, a new Key Jar is minted try: _keyjar = self.upstream_get('attribute', 'keyjar') + if _keyjar is None: + _keyjar = KeyJar() except KeyError: _keyjar = KeyJar() diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 05bcc894..7d171ef9 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -77,6 +77,7 @@ class FetchException(Exception): class RP(oauth2.Client): + client_type = 'oidc' def __init__( self, @@ -93,7 +94,10 @@ def __init__( **kwargs ): self.upstream_get = upstream_get - _srvs = services or DEFAULT_OIDC_SERVICES + if services: + _srvs = services + else: + _srvs = config.get("services", DEFAULT_OIDC_SERVICES) oauth2.Client.__init__( self, diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 547f0ed2..9eb4329f 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -24,7 +24,7 @@ class AccessToken(access_token.AccessToken): default_authn_method = "client_secret_basic" _supports = { - "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index 3c32b687..123b1191 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -28,7 +28,7 @@ class AccessToken(access_token.AccessToken): response_body_type = "urlencoded" _supports = { - "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 0d5db7ab..aec69216 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -34,7 +34,7 @@ class AccessToken(access_token.AccessToken): error_msg = oauth2.TokenErrorResponse _supports = { - "token_endpoint_auth_method": get_client_authn_methods, + "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg_values_supported": get_signing_algs } diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index ece09d01..0fc6f52f 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -674,11 +674,9 @@ def construct_uris(self, else: _path = self._callback_path.get(uri) if isinstance(_path, str): - _callback_uris[uri] = self.get_uri(base_url, self._callback_path.get(_path), - hex) + _callback_uris[uri] = self.get_uri(base_url, _path, hex) else: - _callback_uris[uri] = [self.get_uri(base_url, self._callback_path.get(_var), - hex) for _var in _path] + _callback_uris[uri] = [self.get_uri(base_url, _var, hex) for _var in _path] return _callback_uris diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index a414a1bc..1bcd95b4 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -451,7 +451,7 @@ def verify_client( get_client_id_from_token: Optional[Callable] = None, endpoint=None, # Optional[Endpoint] also_known_as: Optional[Dict[str, str]] = None, -): +) -> dict: """ Initiated Guessing ! diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index f8537bbb..1c591f5b 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "Cy6KTXiLPPvNj33-4kuAsk2diPuoZkCC"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "Nazjj3RrR-yyo33HAWFbMYrH7rhwD77V"}]} \ No newline at end of file diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index 06c25769..f1117ca9 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -209,7 +209,7 @@ def test_conversation(): provider_info_service = entity.get_service("provider_info") info = provider_info_service.get_request_parameters() - assert info["url"] == "https://example.org/op/.well-known/openid" "-configuration" + assert info["url"] == "https://example.org/op/.well-known/openid-configuration" provider_info_response = json.dumps( { From 55859f6f47c2f4d987458b8ec697b026208560f9 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 23 Feb 2023 09:21:07 +0100 Subject: [PATCH 64/76] Fixed tests ! Missed changing claim names. In configuration *_supported names must be used. --- src/idpyoidc/client/claims/oauth2.py | 4 ++-- src/idpyoidc/client/claims/transform.py | 4 ++-- tests/private/token_jwks.json | 2 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/test_08_transform.py | 14 ++++++-------- tests/test_09_work_condition.py | 1 + tests/test_client_02b_entity_metadata.py | 2 +- tests/test_client_21_oidc_service.py | 2 +- tests/test_client_26_read_registration.py | 12 +++++++----- tests/test_client_28_rp_handler_oidc.py | 6 +++--- tests/test_client_30_rph_defaults.py | 1 - tests/test_client_41_rp_handler_persistent.py | 6 +++--- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/idpyoidc/client/claims/oauth2.py b/src/idpyoidc/client/claims/oauth2.py index 59536885..a979faa9 100644 --- a/src/idpyoidc/client/claims/oauth2.py +++ b/src/idpyoidc/client/claims/oauth2.py @@ -7,8 +7,8 @@ class Claims(claims.Claims): _supports = { "redirect_uris": None, - "grant_types": ["authorization_code", "implicit", "refresh_token"], - "response_types": ["code"], + "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "response_types_supported": ["code"], "client_id": None, 'client_secret': None, "client_name": None, diff --git a/src/idpyoidc/client/claims/transform.py b/src/idpyoidc/client/claims/transform.py index 2459774c..744f1a77 100644 --- a/src/idpyoidc/client/claims/transform.py +++ b/src/idpyoidc/client/claims/transform.py @@ -159,13 +159,13 @@ def preferred_to_registered(prefers: dict, supported: dict, registered[key] = val else: logger.warning( - f'OP tells me to do something I do not support: (key) = {val} not within ' + f'OP tells me to do something I do not support: {key} = {val} not within ' f'{_supports}') _val = _intersection(val, _supports) if _val: registered[key] = _val else: - raise ValueError('Not able to support the OPs choice') + raise ValueError(f'Not able to support the OPs choice: {key}={val}') else: registered[key] = val # Should I just accept with the OP says ?? diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 1c591f5b..1575a33f 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "Nazjj3RrR-yyo33HAWFbMYrH7rhwD77V"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "pwcNBtEhyGiqrg0OeikHmSnTRs8_LZrc"}]} \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 77081f40..9b062907 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index dab3adcf..08588bd2 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -42,7 +42,6 @@ def setup(self): self.supported = supported def test_supported(self): - assert 'token_endpoint_auth_methods_supported' not in self.supported # These are all the available configuration parameters assert set(self.supported.keys()) == { 'acr_values_supported', @@ -88,8 +87,8 @@ def test_supported(self): 'scopes_supported', 'sector_identifier_uri', 'subject_types_supported', - 'token_endpoint_auth_method', - # 'token_endpoint_auth_methods_supported', + # 'token_endpoint_auth_method', + 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', 'tos_uri', 'userinfo_encryption_alg_values_supported', @@ -117,7 +116,6 @@ def test_oidc_setup(self): 'require_request_uri_registration', 'service_documentation', 'token_endpoint', - 'token_endpoint_auth_methods_supported', 'ui_locales_supported', 'userinfo_endpoint'} @@ -147,7 +145,6 @@ def test_oidc_setup(self): 'requests_dir', 'require_auth_time', 'sector_identifier_uri', - 'token_endpoint_auth_method', 'tos_uri'} claims = OIDC_Claims() @@ -173,7 +170,7 @@ def test_oidc_setup(self): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_method', + 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported', @@ -187,7 +184,7 @@ def test_oidc_setup(self): reg_claim.append(key) assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == { - 'post_logout_redirect_uri', 'token_endpoint_auth_method'} + 'post_logout_redirect_uri'} # Which ones are list -> singletons @@ -250,7 +247,7 @@ def test_provider_info(self): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_method', + 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported', @@ -350,6 +347,7 @@ def test_registration_response(self): 'request_object_signing_alg', 'response_types', 'subject_type', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 9fbb8b34..b6f41230 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -175,6 +175,7 @@ def test_registration_response(self): 'request_object_signing_alg', 'response_types', 'subject_type', + 'token_endpoint_auth_method', 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index 491c054f..fbc40ef8 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -87,7 +87,7 @@ def test_create_client(): 'response_types_supported', 'scopes_supported', 'subject_types_supported', - 'token_endpoint_auth_method', + 'token_endpoint_auth_methods_supported', 'token_endpoint_auth_signing_alg_values_supported', 'userinfo_signing_alg_values_supported'} diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index a2de0faa..729314a3 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -487,7 +487,7 @@ def create_service(self): "response_types_supported": ["code"], "request_object_signing_alg_values_supported": ["ES256"], "encrypt_id_token_supported": False, # default - "token_endpoint_auth_method": ["private_key_jwt"], + "token_endpoint_auth_methods_supported": ["private_key_jwt"], "token_endpoint_auth_signing_alg_values_supported": ["ES256"], "userinfo_signing_alg_values_supported": ["ES256"], "post_logout_redirect_uris": ["https://rp.example.com/post"], diff --git a/tests/test_client_26_read_registration.py b/tests/test_client_26_read_registration.py index dc53189e..cb8026f9 100644 --- a/tests/test_client_26_read_registration.py +++ b/tests/test_client_26_read_registration.py @@ -1,11 +1,11 @@ import json import time +from cryptojwt.utils import as_bytes import pytest +import requests import responses -from cryptojwt.utils import as_bytes -import requests from idpyoidc.client.entity import Entity from idpyoidc.message.oidc import RegistrationResponse @@ -22,18 +22,20 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", "application_type": "web", - "response_types": ["code"], + "response_types_supported": ["code"], "contacts": ["ops@example.org"], "jwks_uri": "https://example.com/rp/static/jwks.json", "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], - "token_endpoint_auth_method": "client_secret_basic", - "grant_types": ["authorization_code"], + "token_endpoint_auth_methods_supported": ["client_secret_basic"], + "grant_types_supported": ["authorization_code"], } services = { "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, "read_registration": { "class": "idpyoidc.client.oidc.read_registration.RegistrationRead" }, + 'authorization': {'class': 'idpyoidc.client.oidc.authorization.Authorization'}, + 'accesstoken': {'class': 'idpyoidc.client.oidc.access_token.AccessToken'} } self.entity = Entity(config=client_config, services=services) diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_28_rp_handler_oidc.py index f208d211..e6d14553 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_28_rp_handler_oidc.py @@ -66,7 +66,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["r_basicprofile", "r_emailaddress"], - "token_endpoint_auth_method": ["client_secret_post"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], }, "provider_info": { "authorization_endpoint": "https://www.linkedin.com/oauth/v2/authorization", @@ -87,7 +87,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["email", "public_profile"], - "token_endpoint_auth_method": [], + "token_endpoint_auth_methods_supported": [], }, "redirect_uris": ["{}/authz_cb/facebook".format(BASE_URL)], "provider_info": { @@ -115,7 +115,7 @@ "preference": { "response_types_supported": ["code"], "scopes_supported": ["user", "public_repo", 'openid'], - "token_endpoint_auth_method": [], + "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, "provider_info": { diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index 1355bf22..dbef6550 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -46,7 +46,6 @@ def test_init_client(self): 'request_object_encryption_alg_values_supported', 'request_object_encryption_enc_values_supported', 'scopes_supported', - 'token_endpoint_auth_method', 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported'} diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 965392d9..db08eafe 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -55,7 +55,7 @@ "preference": { "response_types": ["code"], "scope": ["r_basicprofile", "r_emailaddress"], - "token_endpoint_auth_method": "client_secret_post", + "token_endpoint_auth_methods_supported": ["client_secret_post"], }, "provider_info": { "authorization_endpoint": "https://www.linkedin.com/oauth/v2/authorization", @@ -76,7 +76,7 @@ "preference": { "response_types": ["code"], "scope": ["email", "public_profile"], - "token_endpoint_auth_method": "", + "token_endpoint_auth_methods_supported": [], }, "redirect_uris": ["{}/authz_cb/facebook".format(BASE_URL)], "provider_info": { @@ -104,7 +104,7 @@ "preference": { "response_types": ["code"], "scopes_supported": ["user", "public_repo"], - "token_endpoint_auth_method": "", + "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, "provider_info": { From 949ebd23300f80cc34c93f5e018123f807dae22b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 23 Feb 2023 09:40:25 +0100 Subject: [PATCH 65/76] One more claim name change. --- tests/test_tandem_10_oauth2_token_exchange.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py index a7745066..d4a9014d 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -190,8 +190,8 @@ def create_endpoint(self): "client_id": "client_1", "redirect_uris": ["https://example.com/cb"], "client_salt": "salted_peanuts_cooking", - "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } client_2_config = { @@ -200,8 +200,8 @@ def create_endpoint(self): "client_secret": "hemligtlösenord", "redirect_uris": ["https://example.com/cb"], "client_salt": "salted_peanuts_cooking", - "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } self.client_1 = Client(client_type='oauth2', config=client_1_config, From 4adc0a42ba4ec6506f05b5d1af1236a91cb77a58 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 6 Mar 2023 21:11:21 +0200 Subject: [PATCH 66/76] Fix registration after fedservice refactor Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/server/oidc/registration.py | 15 +++++++++++++-- .../test_server_23_oidc_registration_endpoint.py | 9 +++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) mode change 100755 => 100644 src/idpyoidc/server/oidc/registration.py diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py old mode 100755 new mode 100644 index 9b1cdef7..7b9d4a7f --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -16,6 +16,7 @@ from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse from idpyoidc.server.endpoint import Endpoint +from idpyoidc.server.exception import CapabilitiesMisMatch from idpyoidc.server.exception import InvalidRedirectURIError from idpyoidc.server.exception import InvalidSectorIdentifier from idpyoidc.time_util import utc_time_sans_frac @@ -155,7 +156,11 @@ def match_claim(self, claim, val): else: return None else: - return list(set(_val).intersection(set(val))) + _ret = list(set(_val).intersection(set(val))) + if len(_ret) > 0: + return _ret + else: + raise CapabilitiesMisMatch(_my_key) else: if val == _val: return val @@ -407,7 +412,13 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): request.rm_blanks() _context = self.upstream_get("context") - request = self.filter_client_request(request) + try: + request = self.filter_client_request(request) + except CapabilitiesMisMatch as err: + return ResponseMessage( + error="invalid_request", + error_description="Don't support proposed %s" % err, + ) if new_id: if self.kwargs.get("client_id_generator"): diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 04a74858..94fe0633 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -338,6 +338,15 @@ def test_register_initiate_login_uri_wrong_scheme(self): assert "error" in _resp assert _resp["error"] == "invalid_configuration_request" + def test_register_unsupported_response_type(self): + self.endpoint.upstream_get("context").provider_info["response_types_supported"] = ["token", "id_token"] + _msg = MSG.copy() + _msg["response_types"] = ["id_token token"] + _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) + _resp = self.endpoint.process_request(request=_req) + assert _resp["error"] == "invalid_request" + assert "response_type" in _resp["error_description"] + def test_match_sp_sep(): assert match_sp_sep("foo bar", "bar foo") From f7fd768bac3a77a46e956e41bfa7e67ad2eb01dd Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 14 Mar 2023 17:37:45 +0100 Subject: [PATCH 67/76] Refactored token endpoint helpers and added support for the two remaining flows: client credentials and resource owner password credentials. --- .../__init__.py | 0 .../cc_access_token.py | 0 .../cc_refresh_access_token.py | 0 .../client/oauth2/client_credentials.py | 43 + .../resource_owner_password_credentials.py | 43 + src/idpyoidc/message/oauth2/__init__.py | 15 +- src/idpyoidc/server/endpoint.py | 7 +- src/idpyoidc/server/oauth2/token.py | 89 +- src/idpyoidc/server/oauth2/token_helper.py | 806 ------------------ .../server/oauth2/token_helper/__init__.py | 176 ++++ .../oauth2/token_helper/access_token.py | 206 +++++ .../oauth2/token_helper/client_credentials.py | 77 ++ .../oauth2/token_helper/refresh_token.py | 147 ++++ .../resource_owner_password_credentials.py | 107 +++ .../oauth2/token_helper/token_exchange.py | 310 +++++++ .../server/oidc/backchannel_authentication.py | 2 +- src/idpyoidc/server/oidc/token.py | 6 +- .../server/oidc/token_helper/__init__.py | 0 .../access_token.py} | 172 ---- .../server/oidc/token_helper/refresh_token.py | 179 ++++ .../oidc/token_helper/token_exchange.py | 14 + src/idpyoidc/server/user_authn/user.py | 40 +- tests/private/token_jwks.json | 2 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/test_05_oauth2.py | 7 +- tests/test_client_25_cc_oauth2_service.py | 177 ---- tests/test_client_25_oauth2_cc_ropc.py | 120 +++ ...st_server_23_oidc_registration_endpoint.py | 2 +- tests/test_server_24_oauth2_token_endpoint.py | 107 ++- tests/test_server_25_oauth2_cc_ropc.py | 0 tests/test_server_31_oauth2_introspection.py | 7 +- .../test_server_32_oidc_read_registration.py | 3 +- tests/test_server_35_oidc_token_endpoint.py | 11 +- tests/test_server_36_oauth2_token_exchange.py | 6 +- ...st_server_38_oauth2_revocation_endpoint.py | 2 +- ...t_server_40_oauth2_pushed_authorization.py | 3 + tests/test_tandem_10_oauth2_token_exchange.py | 13 +- 38 files changed, 1658 insertions(+), 1245 deletions(-) rename src/idpyoidc/client/oauth2/{client_credentials => Xclient_credentials}/__init__.py (100%) rename src/idpyoidc/client/oauth2/{client_credentials => Xclient_credentials}/cc_access_token.py (100%) rename src/idpyoidc/client/oauth2/{client_credentials => Xclient_credentials}/cc_refresh_access_token.py (100%) create mode 100644 src/idpyoidc/client/oauth2/client_credentials.py create mode 100644 src/idpyoidc/client/oauth2/resource_owner_password_credentials.py delete mode 100755 src/idpyoidc/server/oauth2/token_helper.py create mode 100644 src/idpyoidc/server/oauth2/token_helper/__init__.py create mode 100755 src/idpyoidc/server/oauth2/token_helper/access_token.py create mode 100755 src/idpyoidc/server/oauth2/token_helper/client_credentials.py create mode 100755 src/idpyoidc/server/oauth2/token_helper/refresh_token.py create mode 100755 src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py create mode 100755 src/idpyoidc/server/oauth2/token_helper/token_exchange.py create mode 100644 src/idpyoidc/server/oidc/token_helper/__init__.py rename src/idpyoidc/server/oidc/{token_helper.py => token_helper/access_token.py} (55%) create mode 100755 src/idpyoidc/server/oidc/token_helper/refresh_token.py create mode 100755 src/idpyoidc/server/oidc/token_helper/token_exchange.py delete mode 100644 tests/test_client_25_cc_oauth2_service.py create mode 100644 tests/test_client_25_oauth2_cc_ropc.py create mode 100644 tests/test_server_25_oauth2_cc_ropc.py diff --git a/src/idpyoidc/client/oauth2/client_credentials/__init__.py b/src/idpyoidc/client/oauth2/Xclient_credentials/__init__.py similarity index 100% rename from src/idpyoidc/client/oauth2/client_credentials/__init__.py rename to src/idpyoidc/client/oauth2/Xclient_credentials/__init__.py diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py b/src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py similarity index 100% rename from src/idpyoidc/client/oauth2/client_credentials/cc_access_token.py rename to src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py diff --git a/src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py similarity index 100% rename from src/idpyoidc/client/oauth2/client_credentials/cc_refresh_access_token.py rename to src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py diff --git a/src/idpyoidc/client/oauth2/client_credentials.py b/src/idpyoidc/client/oauth2/client_credentials.py new file mode 100644 index 00000000..3c7459de --- /dev/null +++ b/src/idpyoidc/client/oauth2/client_credentials.py @@ -0,0 +1,43 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.client.service import Service +from idpyoidc.message import Message +from idpyoidc.message import oauth2 +from idpyoidc.time_util import time_sans_frac + + +class CCAccessTokenRequest(Service): + """The service that talks to the OAuth2 client credentials endpoint.""" + + msg_type = oauth2.CCAccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = oauth2.ResponseMessage + endpoint_name = "token_endpoint" + synchronous = True + service_name = "client_credentials" + default_authn_method = "" + http_method = "POST" + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) + self.pre_construct.append(self.cc_pre_construct) + + def cc_pre_construct(self, + request: Union[Message, dict], + service: Service, + post_args: Optional[dict], + **_args): + _grant_type = request.get('grant_type') + if not _grant_type: + request['grant_type'] = 'client_credentials' + elif _grant_type != 'client_credentials': + logging.error('Wrong grant_type') + + return request, post_args + + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): + if "expires_in" in resp: + resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) + self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py b/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py new file mode 100644 index 00000000..e2148035 --- /dev/null +++ b/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py @@ -0,0 +1,43 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.client.service import Service +from idpyoidc.message import Message +from idpyoidc.message import oauth2 +from idpyoidc.time_util import time_sans_frac + + +class ROPCAccessTokenRequest(Service): + """The service uses the OAuth2 resource owner password credentials flow.""" + + msg_type = oauth2.ROPCAccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = oauth2.ResponseMessage + endpoint_name = "token_endpoint" + synchronous = True + service_name = "resource_owner_password_credentials" + default_authn_method = "" + http_method = "POST" + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) + self.pre_construct.append(self.ropc_pre_construct) + + def ropc_pre_construct(self, + request: Union[Message, dict], + service: Service, + post_args: Optional[dict], + **_args): + _grant_type = request.get('grant_type') + if not _grant_type: + request['grant_type'] = 'password' + elif _grant_type != 'password': + logging.error('Wrong grant_type') + + return request, post_args + + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): + if "expires_in" in resp: + resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) + self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 717c55e6..723526cf 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -287,6 +287,8 @@ class ROPCAccessTokenRequest(Message): "username": SINGLE_OPTIONAL_STRING, "password": SINGLE_OPTIONAL_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "client_id": SINGLE_OPTIONAL_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, } @@ -295,9 +297,16 @@ class CCAccessTokenRequest(Message): Client Credential grant flow access token request """ - c_param = {"grant_type": SINGLE_REQUIRED_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS} - c_default = {"grant_type": "client_credentials"} - c_allowed_values = {"grant_type": ["client_credentials"]} + c_param = { + "client_id": SINGLE_OPTIONAL_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + "grant_type": SINGLE_REQUIRED_STRING, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + } + + def verify(self, **kwargs): + if self['grant_type'] != 'client_credentials': + raise ValueError('Grant type MUST be client_credentials') class RefreshAccessTokenRequest(Message): diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index e1c5f5bb..4ffd8bff 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -225,6 +225,11 @@ def parse_request( if "client_id" in auth_info: req["client_id"] = auth_info["client_id"] + + _auth_method = auth_info.get('method') + if _auth_method and _auth_method != 'public': + req['authenticated'] = True + _client_id = auth_info["client_id"] else: _client_id = req.get("client_id") @@ -239,7 +244,7 @@ def parse_request( # Do any endpoint specific parsing return self.do_post_parse_request( - request=req, client_id=_client_id, http_info=http_info, **kwargs + request=req, client_id=_client_id, http_info=http_info, auth_info=auth_info, **kwargs ) def client_authentication(self, request: Message, http_info: Optional[dict] = None, **kwargs): diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 987bd39d..93081309 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -10,11 +10,14 @@ from idpyoidc.message.oidc import TokenErrorResponse from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ProcessError -from idpyoidc.server.oauth2.token_helper import AccessTokenHelper -from idpyoidc.server.oauth2.token_helper import RefreshTokenHelper -from idpyoidc.server.oauth2.token_helper import TokenExchangeHelper +from idpyoidc.server.oauth2.token_helper import TokenEndpointHelper from idpyoidc.server.session import MintingNotAllowed from idpyoidc.util import importer +from .token_helper.access_token import AccessTokenHelper +from .token_helper.client_credentials import ClientCredentials +from .token_helper.refresh_token import RefreshTokenHelper +from .token_helper.resource_owner_password_credentials import ResourceOwnerPasswordCredentials +from .token_helper.token_exchange import TokenExchangeHelper logger = logging.getLogger(__name__) @@ -34,6 +37,8 @@ class Token(Endpoint): "authorization_code": AccessTokenHelper, "refresh_token": RefreshTokenHelper, "urn:ietf:params:oauth:grant-type:token-exchange": TokenExchangeHelper, + "client_credentials": ClientCredentials, + "resource_owner_password_credentials": ResourceOwnerPasswordCredentials, } token_exchange_helper = TokenExchangeHelper @@ -42,63 +47,81 @@ def __init__(self, upstream_get, new_refresh_token=False, **kwargs): self.post_parse_request.append(self._post_parse_request) self.allow_refresh = False self.new_refresh_token = new_refresh_token - self.configure_grant_types(kwargs.get("grant_types_helpers")) + self.grant_type_helper = self.configure_types(kwargs.get("grant_types_helpers"), + self.helper_by_grant_type) self.grant_types_supported = kwargs.get("grant_types_supported", - list(self.helper_by_grant_type.keys())) + list(self.grant_type_helper.keys())) self.revoke_refresh_on_issue = kwargs.get("revoke_refresh_on_issue", False) self.resource_indicators_config = kwargs.get('resource_indicators', None) - def configure_grant_types(self, grant_types_helpers): - if grant_types_helpers is None: - self.helper = {k: v(self) for k, v in self.helper_by_grant_type.items()} - return + def configure_types(self, helpers, default_helpers): + if helpers is None: + return {k: v(self) for k, v in default_helpers.items()} - self.helper = {} - # TODO: do we want to allow any grant_type? - for grant_type, grant_type_options in grant_types_helpers.items(): - _conf = grant_type_options.get("kwargs", {}) - if _conf is False: + _helper = {} + for type, args in helpers.items(): + _kwargs = args.get("kwargs", {}) + if _kwargs is False: continue try: - grant_class = grant_type_options["class"] + _class = args["class"] except (KeyError, TypeError): raise ProcessError( "Token Endpoint's grant types must be True, None or a dict with a" " 'class' key." ) - if isinstance(grant_class, str): + if isinstance(_class, str): try: - grant_class = importer(grant_class) + _class = importer(_class) except (ValueError, AttributeError): raise ProcessError( - f"Token Endpoint's grant type class {grant_class} can't" " be imported." + f"Token Endpoint's helper class {_class} can't" " be imported." ) try: - self.helper[grant_type] = grant_class(self, _conf) + _helper[type] = _class(self, _kwargs) except Exception as e: - raise ProcessError(f"Failed to initialize class {grant_class}: {e}") + raise ProcessError(f"Failed to initialize class {_class}: {e}") + + return _helper + + def _get_helper(self, + request: Union[Message, dict], + client_id: Optional[str] = "") -> Optional[Union[Message, TokenEndpointHelper]]: + grant_type = request.get('grant_type') + if grant_type: + _client_id = client_id or request.get('client_id') + if client_id: + client = self.upstream_get('context').cdb[client_id] + grant_types_supported = client.get("grant_types_supported", + self.grant_types_supported) + if grant_type not in grant_types_supported: + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {grant_type}", + ) - def _post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs - ): - grant_type = request["grant_type"] - _helper = self.helper.get(grant_type) - client = kwargs["context"].cdb[client_id] - grant_types_supported = client.get("grant_types_supported", self.grant_types_supported) - if grant_type not in grant_types_supported: + return self.grant_type_helper.get(grant_type) + else: return self.error_cls( error="invalid_request", - error_description=f"Unsupported grant_type: {grant_type}", + error_description=f"Do not know how to handle this type of request", ) - if _helper: - return _helper.post_parse_request(request, client_id, **kwargs) + + def _post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + _resp = self._get_helper(request, client_id) + if isinstance(_resp, TokenEndpointHelper): + return _resp.post_parse_request(request, client_id, **kwargs) + elif _resp: + return _resp else: return self.error_cls( error="invalid_request", - error_description=f"Unsupported grant_type: {grant_type}", + error_description=f"Do not know how to handle this type of request", ) def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs): @@ -115,7 +138,7 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar return self.error_cls(error="invalid_request") try: - _helper = self.helper.get(request["grant_type"]) + _helper = self._get_helper(request) if _helper: response_args = _helper.process_request(request, **kwargs) else: diff --git a/src/idpyoidc/server/oauth2/token_helper.py b/src/idpyoidc/server/oauth2/token_helper.py deleted file mode 100755 index d6c3a2c1..00000000 --- a/src/idpyoidc/server/oauth2/token_helper.py +++ /dev/null @@ -1,806 +0,0 @@ -import logging -from typing import Optional -from typing import Union - -from cryptojwt import BadSyntax -from cryptojwt.exception import JWKESTException - -from idpyoidc.exception import ImproperlyConfigured -from idpyoidc.exception import MissingRequiredAttribute -from idpyoidc.exception import MissingRequiredValue -from idpyoidc.message import Message -from idpyoidc.message.oauth2 import TokenExchangeRequest -from idpyoidc.message.oauth2 import TokenExchangeResponse -from idpyoidc.message.oidc import RefreshAccessTokenRequest -from idpyoidc.message.oidc import TokenErrorResponse -from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE -from idpyoidc.server.constant import DEFAULT_TOKEN_LIFETIME -from idpyoidc.server.exception import ToOld -from idpyoidc.server.exception import UnAuthorizedClientScope -from idpyoidc.server.oauth2.authorization import check_unknown_scopes_policy -from idpyoidc.server.session.grant import Grant -from idpyoidc.server.session.token import AuthorizationCode -from idpyoidc.server.session.token import MintingNotAllowed -from idpyoidc.server.session.token import RefreshToken -from idpyoidc.server.session.token import SessionToken -from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING -from idpyoidc.server.token.exception import UnknownToken -from idpyoidc.time_util import utc_time_sans_frac -from idpyoidc.util import importer -from idpyoidc.util import sanitize - -logger = logging.getLogger(__name__) - - -class TokenEndpointHelper(object): - - def __init__(self, endpoint, config=None): - self.endpoint = endpoint - self.config = config - self.error_cls = self.endpoint.error_cls - - def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs - ): - """Context specific parsing of the request. - This is done after general request parsing and before processing - the request. - """ - raise NotImplementedError - - def process_request(self, req: Union[Message, dict], **kwargs): - """Acts on a process request.""" - raise NotImplementedError - - def _mint_token( - self, - token_class: str, - grant: Grant, - session_id: str, - client_id: str, - based_on: Optional[SessionToken] = None, - scope: Optional[list] = None, - token_args: Optional[dict] = None, - token_type: Optional[str] = "", - ) -> SessionToken: - _context = self.endpoint.upstream_get("context") - _mngr = _context.session_manager - usage_rules = grant.usage_rules.get(token_class) - if usage_rules: - _exp_in = usage_rules.get("expires_in") - else: - _exp_in = DEFAULT_TOKEN_LIFETIME - - token_args = token_args or {} - for meth in _context.token_args_methods: - token_args = meth(_context, client_id, token_args) - - if token_args: - _args = token_args - else: - _args = {} - - token = grant.mint_token( - session_id, - context=_context, - token_class=token_class, - token_handler=_mngr.token_handler[token_class], - based_on=based_on, - usage_rules=usage_rules, - scope=scope, - token_type=token_type, - **_args, - ) - - if _exp_in: - if isinstance(_exp_in, str): - _exp_in = int(_exp_in) - - if _exp_in: - token.expires_at = utc_time_sans_frac() + _exp_in - - _context.session_manager.set(_context.session_manager.unpack_session_key(session_id), grant) - - return token - - -def validate_resource_indicators_policy(request, context, **kwargs): - if "resource" not in request: - return TokenErrorResponse( - error="invalid_target", - error_description="Missing resource parameter", - ) - - resource_servers_per_client = kwargs["resource_servers_per_client"] - client_id = request["client_id"] - - resource_servers_per_client = kwargs.get("resource_servers_per_client", None) - - if isinstance(resource_servers_per_client, - dict) and client_id not in resource_servers_per_client: - return TokenErrorResponse( - error="invalid_target", - error_description=f"Resources for client {client_id} not found", - ) - - if isinstance(resource_servers_per_client, dict): - permitted_resources = [res for res in resource_servers_per_client[client_id]] - else: - permitted_resources = [res for res in resource_servers_per_client] - - common_resources = list(set(request["resource"]).intersection(set(permitted_resources))) - if not common_resources: - return TokenErrorResponse( - error="invalid_target", - error_description=f"Invalid resource requested by client {client_id}", - ) - - common_resources = [r for r in common_resources if r in context.cdb.keys()] - if not common_resources: - return TokenErrorResponse( - error="invalid_target", - error_description=f"Invalid resource requested by client {client_id}", - ) - - if client_id not in common_resources: - common_resources.append(client_id) - - request["resource"] = common_resources - - permitted_scopes = [context.cdb[r]["allowed_scopes"] for r in common_resources] - permitted_scopes = [r for res in permitted_scopes for r in res] - scopes = list(set(request.get("scope", [])).intersection(set(permitted_scopes))) - request["scope"] = scopes - return request - - -class AccessTokenHelper(TokenEndpointHelper): - - def process_request(self, req: Union[Message, dict], **kwargs): - """ - - :param req: - :param kwargs: - :return: - """ - _context = self.endpoint.upstream_get("context") - _mngr = _context.session_manager - logger.debug("Access Token") - - if req["grant_type"] != "authorization_code": - return self.error_cls(error="invalid_request", error_description="Unknown grant_type") - - try: - _access_code = req["code"].replace(" ", "+") - except KeyError: # Missing code parameter - absolutely fatal - return self.error_cls(error="invalid_request", error_description="Missing code") - - _session_info = _mngr.get_session_info_by_token( - _access_code, grant=True, handler_key="authorization_code" - ) - client_id = _session_info["client_id"] - if client_id != req["client_id"]: - logger.debug("{} owner of token".format(client_id)) - logger.warning("Client using token it was not given") - return self.error_cls(error="invalid_grant", error_description="Wrong client") - - _cinfo = self.endpoint.upstream_get("context").cdb.get(client_id) - - if ("resource_indicators" in _cinfo - and "access_token" in _cinfo["resource_indicators"]): - resource_indicators_config = _cinfo["resource_indicators"]["access_token"] - else: - resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None) - - if resource_indicators_config is not None: - if "policy" not in resource_indicators_config: - policy = {"policy": {"callable": validate_resource_indicators_policy}} - resource_indicators_config.update(policy) - - req = self._enforce_resource_indicators_policy(req, resource_indicators_config) - - if isinstance(req, TokenErrorResponse): - return req - - # if "grant_types_supported" in _context.cdb[client_id]: - # grant_types_supported = _context.cdb[client_id].get("grant_types_supported") - # else: - # grant_types_supported = _context.provider_info["grant_types_supported"] - - grant = _session_info["grant"] - - _based_on = grant.get_token(_access_code) - _supports_minting = _based_on.usage_rules.get("supports_minting", []) - - _authn_req = grant.authorization_request - - # If redirect_uri was in the initial authorization request - # verify that the one given here is the correct one. - if "redirect_uri" in _authn_req: - if req["redirect_uri"] != _authn_req["redirect_uri"]: - return self.error_cls( - error="invalid_request", error_description="redirect_uri mismatch" - ) - - logger.debug("All checks OK") - - issue_refresh = kwargs.get("issue_refresh", False) - - if resource_indicators_config is not None: - scope = req["scope"] - else: - scope = grant.scope - - _response = { - "token_type": "Bearer", - "scope": scope, - } - - if "access_token" in _supports_minting: - - resources = req.get("resource", None) - if resources: - token_args = {"resources": resources} - else: - token_args = None - - try: - token = self._mint_token( - token_class="access_token", - grant=grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=_based_on, - token_args=token_args - ) - except MintingNotAllowed as err: - logger.warning(err) - else: - _response["access_token"] = token.value - if token.expires_at: - _response["expires_in"] = token.expires_at - utc_time_sans_frac() - - if ( - issue_refresh - and "refresh_token" in _supports_minting - ): - try: - refresh_token = self._mint_token( - token_class="refresh_token", - grant=grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=_based_on, - ) - except MintingNotAllowed as err: - logger.warning(err) - else: - _response["refresh_token"] = refresh_token.value - - # since the grant content has changed. Make sure it's stored - _mngr[_session_info["branch_id"]] = grant - - _based_on.register_usage() - - return _response - - def _enforce_resource_indicators_policy(self, request, config): - _context = self.endpoint.upstream_get('context') - - policy = config["policy"] - callable = policy["callable"] - kwargs = policy.get("kwargs", {}) - - if isinstance(callable, str): - try: - fn = importer(callable) - except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") - else: - fn = callable - try: - return fn(request, context=_context, **kwargs) - except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") - return self.error_cls(error="server_error", error_description="Internal server error") - - def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs - ): - """ - This is where clients come to get their access tokens - - :param request: The request - :param client_id: Client identifier - :returns: - """ - - _mngr = self.endpoint.upstream_get("context").session_manager - try: - _session_info = _mngr.get_session_info_by_token( - request["code"], grant=True, handler_key="authorization_code" - ) - except (KeyError, UnknownToken): - logger.error("Access Code invalid") - return self.error_cls(error="invalid_grant", error_description="Unknown code") - - grant = _session_info["grant"] - code = grant.get_token(request["code"]) - if not isinstance(code, AuthorizationCode): - return self.error_cls(error="invalid_request", error_description="Wrong token type") - - if code.is_active() is False: - return self.error_cls(error="invalid_request", error_description="Code inactive") - - _auth_req = grant.authorization_request - - if "client_id" not in request: # Optional for access token request - request["client_id"] = _auth_req["client_id"] - - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) - - return request - - -class RefreshTokenHelper(TokenEndpointHelper): - - def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.upstream_get("context") - _mngr = _context.session_manager - logger.debug("Refresh Token") - - if req["grant_type"] != "refresh_token": - return self.error_cls(error="invalid_request", error_description="Wrong grant_type") - - token_value = req["refresh_token"] - _session_info = _mngr.get_session_info_by_token( - token_value, grant=True, handler_key="refresh_token" - ) - logger.debug("Session info: {}".format(_session_info)) - - if _session_info["client_id"] != req["client_id"]: - logger.debug("{} owner of token".format(_session_info["client_id"])) - logger.warning("Client using token it was not given") - return self.error_cls(error="invalid_grant", error_description="Wrong client") - - _grant = _session_info["grant"] - - token_type = "Bearer" - - # Is DPOP supported - if "dpop_signing_alg_values_supported" in _context.provider_info: - _dpop_jkt = req.get("dpop_jkt") - if _dpop_jkt: - _grant.extra["dpop_jkt"] = _dpop_jkt - token_type = "DPoP" - - token = _grant.get_token(token_value) - scope = _grant.find_scope(token) - if "scope" in req: - scope = req["scope"] - access_token = self._mint_token( - token_class="access_token", - grant=_grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=token, - scope=scope, - token_type=token_type, - ) - - _resp = { - "access_token": access_token.value, - "token_type": access_token.token_type, - "scope": scope, - } - - if access_token.expires_at: - _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() - - _mints = token.usage_rules.get("supports_minting") - issue_refresh = kwargs.get("issue_refresh", False) - if "refresh_token" in _mints and issue_refresh: - refresh_token = self._mint_token( - token_class="refresh_token", - grant=_grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=token, - scope=scope, - ) - refresh_token.usage_rules = token.usage_rules.copy() - _resp["refresh_token"] = refresh_token.value - - token.register_usage() - - if ( - "client_id" in req - and req["client_id"] in _context.cdb - and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] - ): - revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") - else: - revoke_refresh = self.endpoint.revoke_refresh_on_issue - - if revoke_refresh: - token.revoke() - - return _resp - - def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs - ): - """ - This is where clients come to refresh their access tokens - - :param request: The request - :param client_id: Client identifier - :returns: - """ - - request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.upstream_get("context") - - request.verify( - keyjar=self.endpoint.upstream_get('sttribute', 'keyjar'), - opponent_id=client_id) - - _mngr = _context.session_manager - try: - _session_info = _mngr.get_session_info_by_token( - request["refresh_token"], grant=True, handler_key="refresh_token" - ) - except (KeyError, UnknownToken): - logger.error("Refresh token invalid") - return self.error_cls(error="invalid_grant", error_description="Invalid refresh token") - - grant = _session_info["grant"] - token = grant.get_token(request["refresh_token"]) - - if not isinstance(token, RefreshToken): - return self.error_cls(error="invalid_request", error_description="Wrong token type") - - if token.is_active() is False: - return self.error_cls( - error="invalid_request", error_description="Refresh token inactive" - ) - - if "scope" in request: - req_scopes = set(request["scope"]) - scopes = set(grant.find_scope(token.based_on)) - if not req_scopes.issubset(scopes): - return self.error_cls( - error="invalid_request", - error_description="Invalid refresh scopes", - ) - - return request - - -class TokenExchangeHelper(TokenEndpointHelper): - """Implements Token Exchange a.k.a. RFC8693""" - - token_types_mapping = { - "urn:ietf:params:oauth:token-type:access_token": "access_token", - "urn:ietf:params:oauth:token-type:refresh_token": "refresh_token", - } - - def __init__(self, endpoint, config=None): - TokenEndpointHelper.__init__(self, endpoint=endpoint, config=config) - if config is None: - self.config = { - "requested_token_types_supported": [ - "urn:ietf:params:oauth:token-type:access_token", - "urn:ietf:params:oauth:token-type:refresh_token", - ], - "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", - "policy": {"": {"callable": validate_token_exchange_policy}}, - } - else: - self.config = config - - def post_parse_request(self, request, client_id="", **kwargs): - request = TokenExchangeRequest(**request.to_dict()) - - _context = self.endpoint.upstream_get("context") - if "token_exchange" in _context.cdb[request["client_id"]]: - config = _context.cdb[request["client_id"]]["token_exchange"] - else: - config = self.config - - try: - request.verify( - keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), - opponent_id=client_id - ) - except ( - MissingRequiredAttribute, - ValueError, - MissingRequiredValue, - JWKESTException, - ) as err: - return self.endpoint.error_cls(error="invalid_request", error_description="%s" % err) - - self._validate_configuration(config) - - _mngr = _context.session_manager - try: - # token exchange is about minting one token based on another - _handler_key = self.token_types_mapping[request["subject_token_type"]] - _session_info = _mngr.get_session_info_by_token( - request["subject_token"], grant=True, handler_key=_handler_key - ) - except (KeyError, UnknownToken, BadSyntax) as err: - logger.error(f"Subject token invalid ({err}).") - return self.error_cls( - error="invalid_request", error_description="Subject token invalid" - ) - - # Find the token instance based on the token value - token = _mngr.find_token(_session_info["branch_id"], request["subject_token"]) - if token.is_active() is False: - return self.error_cls( - error="invalid_request", error_description="Subject token inactive" - ) - - resp = self._enforce_policy(request, token, config) - if isinstance(resp, TokenErrorResponse): - return resp - - scopes = resp.get("scope", []) - scopes = _context.scopes_handler.filter_scopes(scopes, client_id=resp["client_id"]) - - if not scopes: - logger.error("All requested scopes have been filtered out.") - return self.error_cls( - error="invalid_scope", error_description="Invalid requested scopes" - ) - - _requested_token_type = resp.get( - "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" - ) - _token_class = self.token_types_mapping[_requested_token_type] - if _token_class == "refresh_token" and "offline_access" not in scopes: - return TokenErrorResponse( - error="invalid_request", - error_description="Exchanging this subject token to refresh token forbidden", - ) - - return resp - - def _enforce_policy(self, request, token, config): - _context = self.endpoint.upstream_get("context") - subject_token_types_supported = config.get( - "subject_token_types_supported", self.token_types_mapping.keys() - ) - subject_token_type = request["subject_token_type"] - if subject_token_type not in subject_token_types_supported: - return TokenErrorResponse( - error="invalid_request", - error_description="Unsupported subject token type", - ) - if self.token_types_mapping[subject_token_type] != token.token_class: - return TokenErrorResponse( - error="invalid_request", - error_description="Wrong token type", - ) - - if ( - "requested_token_type" in request - and request["requested_token_type"] not in config["requested_token_types_supported"] - ): - return TokenErrorResponse( - error="invalid_request", - error_description="Unsupported requested token type", - ) - - request_info = dict(scope=request.get("scope", token.scope)) - try: - check_unknown_scopes_policy(request_info, request["client_id"], _context) - except UnAuthorizedClientScope: - return self.error_cls( - error="invalid_grant", - error_description="Unauthorized scope requested", - ) - - if subject_token_type not in config["policy"]: - subject_token_type = "" - - policy = config["policy"][subject_token_type] - callable = policy["callable"] - kwargs = policy.get("kwargs", {}) - - if isinstance(callable, str): - try: - fn = importer(callable) - except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") - else: - fn = callable - - try: - return fn(request, context=_context, subject_token=token, **kwargs) - except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") - return self.error_cls(error="server_error", error_description="Internal server error") - - def token_exchange_response(self, token, issued_token_type): - response_args = {} - response_args["access_token"] = token.value - response_args["scope"] = token.scope - response_args["issued_token_type"] = issued_token_type - - if token.expires_at: - response_args["expires_in"] = token.expires_at - utc_time_sans_frac() - if hasattr(token, "token_type"): - response_args["token_type"] = token.token_type - else: - response_args["token_type"] = "N_A" - - return TokenExchangeResponse(**response_args) - - def process_request(self, request, **kwargs): - _context = self.endpoint.upstream_get("context") - _mngr = _context.session_manager - try: - _handler_key = self.token_types_mapping[request["subject_token_type"]] - _session_info = _mngr.get_session_info_by_token( - request["subject_token"], grant=True, handler_key=_handler_key - ) - except ToOld: - logger.error("Subject token has expired.") - return self.error_cls( - error="invalid_request", error_description="Subject token has expired" - ) - except (KeyError, UnknownToken): - logger.error("Subject token invalid.") - return self.error_cls( - error="invalid_request", error_description="Subject token invalid" - ) - - grant = _session_info["grant"] - token = _mngr.find_token(_session_info["branch_id"], request["subject_token"]) - _requested_token_type = request.get( - "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" - ) - - _token_class = self.token_types_mapping[_requested_token_type] - - sid = _session_info["branch_id"] - - _token_type = "Bearer" - # Is DPOP supported - if "dpop_signing_alg_values_supported" in _context.provider_info: - if request.get("dpop_jkt"): - _token_type = "DPoP" - scopes = request.get("scope", []) - - if request["client_id"] != _session_info["client_id"]: - _token_usage_rules = _context.authz.usage_rules(request["client_id"]) - - sid = _mngr.create_exchange_session( - exchange_request=request, - original_grant=grant, - original_session_id=sid, - user_id=_session_info["user_id"], - client_id=request["client_id"], - token_usage_rules=_token_usage_rules, - scopes=scopes, - ) - - try: - _session_info = _mngr.get_session_info(session_id=sid, grant=True) - except Exception: - logger.error("Error retrieving token exchange session information") - return self.error_cls( - error="server_error", error_description="Internal server error" - ) - - resources = request.get("resource") - if resources and request.get("audience"): - resources = list(set(resources + request.get("audience"))) - else: - resources = request.get("audience") - - _token_args = None - if resources: - _token_args = {"resources": resources} - - try: - new_token = self._mint_token( - token_class=_token_class, - grant=_session_info["grant"], - session_id=sid, - client_id=request["client_id"], - based_on=token, - scope=scopes, - token_args=_token_args, - token_type=_token_type, - ) - new_token.expires_at = token.expires_at - except MintingNotAllowed: - logger.error(f"Minting not allowed for {_token_class}") - return self.error_cls( - error="invalid_grant", - error_description="Token Exchange not allowed with that token", - ) - - return self.token_exchange_response(new_token, _requested_token_type) - - def _validate_configuration(self, config): - if "requested_token_types_supported" not in config: - raise ImproperlyConfigured( - "Missing 'requested_token_types_supported' from Token Exchange configuration" - ) - if "policy" not in config: - raise ImproperlyConfigured("Missing 'policy' from Token Exchange configuration") - if "" not in config["policy"]: - raise ImproperlyConfigured( - "Default Token Exchange policy configuration is not defined" - ) - if "callable" not in config["policy"][""]: - raise ImproperlyConfigured( - "Missing 'callable' from default Token Exchange policy configuration" - ) - - _default_requested_token_type = config.get("default_requested_token_type", - DEFAULT_REQUESTED_TOKEN_TYPE) - if _default_requested_token_type not in config["requested_token_types_supported"]: - raise ImproperlyConfigured( - f"Unsupported default requested_token_type {_default_requested_token_type}" - ) - - def get_handler_key(self, request, endpoint_context): - client_info = endpoint_context.cdb.get(request["client_id"], {}) - - default_requested_token_type = ( - client_info.get("token_exchange", {}).get("default_requested_token_type", None) - or - self.config.get("default_requested_token_type", DEFAULT_REQUESTED_TOKEN_TYPE) - ) - - requested_token_type = request.get("requested_token_type", default_requested_token_type) - return TOKEN_TYPES_MAPPING[requested_token_type] - - -def validate_token_exchange_policy(request, context, subject_token, **kwargs): - if "resource" in request: - resource = kwargs.get("resource", []) - if not set(request["resource"]).issubset(set(resource)): - return TokenErrorResponse(error="invalid_target", error_description="Unknown resource") - - if "audience" in request: - if request["subject_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token": - return TokenErrorResponse( - error="invalid_target", error_description="Refresh token has single owner" - ) - audience = kwargs.get("audience", []) - if audience and not set(request["audience"]).issubset(set(audience)): - return TokenErrorResponse(error="invalid_target", error_description="Unknown audience") - - if "actor_token" in request or "actor_token_type" in request: - return TokenErrorResponse( - error="invalid_request", error_description="Actor token not supported" - ) - - if ( - "requested_token_type" in request - and request["requested_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token" - ): - if "offline_access" not in subject_token.scope: - return TokenErrorResponse( - error="invalid_request", - error_description=f"Exchange {request['subject_token_type']} to refresh token " - f"forbidden", - ) - - scopes = request.get("scope", subject_token.scope) - scopes = list(set(scopes).intersection(subject_token.scope)) - if kwargs.get("scope"): - scopes = list(set(scopes).intersection(kwargs.get("scope"))) - if scopes: - request["scope"] = scopes - else: - request.pop("scope") - - return request diff --git a/src/idpyoidc/server/oauth2/token_helper/__init__.py b/src/idpyoidc/server/oauth2/token_helper/__init__.py new file mode 100644 index 00000000..e9bbc96e --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/__init__.py @@ -0,0 +1,176 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.message import Message +from idpyoidc.message.oidc import TokenErrorResponse +from idpyoidc.server.constant import DEFAULT_TOKEN_LIFETIME +from idpyoidc.server.session.grant import Grant +from idpyoidc.server.session.token import SessionToken +from idpyoidc.time_util import utc_time_sans_frac + +logger = logging.getLogger(__name__) + + +class TokenEndpointHelper(object): + + def __init__(self, endpoint, config=None): + self.endpoint = endpoint + self.config = config + self.error_cls = self.endpoint.error_cls + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """Context specific parsing of the request. + This is done after general request parsing and before processing + the request. + """ + raise NotImplementedError + + def process_request(self, req: Union[Message, dict], **kwargs): + """Acts on a process request.""" + raise NotImplementedError + + def _mint_token( + self, + token_class: str, + grant: Grant, + session_id: str, + client_id: str, + based_on: Optional[SessionToken] = None, + scope: Optional[list] = None, + token_args: Optional[dict] = None, + token_type: Optional[str] = "", + ) -> SessionToken: + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + usage_rules = grant.usage_rules.get(token_class) + if usage_rules: + _exp_in = usage_rules.get("expires_in") + else: + _exp_in = DEFAULT_TOKEN_LIFETIME + + token_args = token_args or {} + for meth in _context.token_args_methods: + token_args = meth(_context, client_id, token_args) + + if token_args: + _args = token_args + else: + _args = {} + + token = grant.mint_token( + session_id, + context=_context, + token_class=token_class, + token_handler=_mngr.token_handler[token_class], + based_on=based_on, + usage_rules=usage_rules, + scope=scope, + token_type=token_type, + **_args, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + + if _exp_in: + token.expires_at = utc_time_sans_frac() + _exp_in + + _context.session_manager.set(_context.session_manager.unpack_session_key(session_id), grant) + + return token + + +def validate_resource_indicators_policy(request, context, **kwargs): + if "resource" not in request: + return TokenErrorResponse( + error="invalid_target", + error_description="Missing resource parameter", + ) + + client_id = request["client_id"] + + resource_servers_per_client = kwargs.get("resource_servers_per_client", []) + + if isinstance(resource_servers_per_client, + dict) and client_id not in resource_servers_per_client: + return TokenErrorResponse( + error="invalid_target", + error_description=f"Resources for client {client_id} not found", + ) + + if isinstance(resource_servers_per_client, dict): + permitted_resources = [res for res in resource_servers_per_client[client_id]] + else: + permitted_resources = [res for res in resource_servers_per_client] + + common_resources = list(set(request["resource"]).intersection(set(permitted_resources))) + if not common_resources: + return TokenErrorResponse( + error="invalid_target", + error_description=f"Invalid resource requested by client {client_id}", + ) + + common_resources = [r for r in common_resources if r in context.cdb.keys()] + if not common_resources: + return TokenErrorResponse( + error="invalid_target", + error_description=f"Invalid resource requested by client {client_id}", + ) + + if client_id not in common_resources: + common_resources.append(client_id) + + request["resource"] = common_resources + + permitted_scopes = [context.cdb[r]["allowed_scopes"] for r in common_resources] + permitted_scopes = [r for res in permitted_scopes for r in res] + scopes = list(set(request.get("scope", [])).intersection(set(permitted_scopes))) + request["scope"] = scopes + return request + + +def validate_token_exchange_policy(request, context, subject_token, **kwargs): + if "resource" in request: + resource = kwargs.get("resource", []) + if not set(request["resource"]).issubset(set(resource)): + return TokenErrorResponse(error="invalid_target", error_description="Unknown resource") + + if "audience" in request: + if request["subject_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token": + return TokenErrorResponse( + error="invalid_target", error_description="Refresh token has single owner" + ) + audience = kwargs.get("audience", []) + if audience and not set(request["audience"]).issubset(set(audience)): + return TokenErrorResponse(error="invalid_target", error_description="Unknown audience") + + if "actor_token" in request or "actor_token_type" in request: + return TokenErrorResponse( + error="invalid_request", error_description="Actor token not supported" + ) + + if ( + "requested_token_type" in request + and request["requested_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token" + ): + if "offline_access" not in subject_token.scope: + return TokenErrorResponse( + error="invalid_request", + error_description=f"Exchange {request['subject_token_type']} to refresh token " + f"forbidden", + ) + + scopes = request.get("scope", subject_token.scope) + scopes = list(set(scopes).intersection(subject_token.scope)) + if kwargs.get("scope"): + scopes = list(set(scopes).intersection(kwargs.get("scope"))) + if scopes: + request["scope"] = scopes + else: + request.pop("scope") + + return request diff --git a/src/idpyoidc/server/oauth2/token_helper/access_token.py b/src/idpyoidc/server/oauth2/token_helper/access_token.py new file mode 100755 index 00000000..b7b917fe --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/access_token.py @@ -0,0 +1,206 @@ +import logging +from typing import Optional +from typing import Union + +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.utils import importer + +from idpyoidc.exception import ImproperlyConfigured +from idpyoidc.message import Message +from idpyoidc.message.oauth2 import TokenErrorResponse +from idpyoidc.util import sanitize +from . import TokenEndpointHelper +from . import validate_resource_indicators_policy +from ...session import MintingNotAllowed +from ...session.token import AuthorizationCode +from ...token import UnknownToken + +logger = logging.getLogger(__name__) + + +class AccessTokenHelper(TokenEndpointHelper): + + def process_request(self, req: Union[Message, dict], **kwargs): + """ + + :param req: + :param kwargs: + :return: + """ + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + logger.debug("Access Token") + + if req["grant_type"] != "authorization_code": + return self.error_cls(error="invalid_request", error_description="Unknown grant_type") + + try: + _access_code = req["code"].replace(" ", "+") + except KeyError: # Missing code parameter - absolutely fatal + return self.error_cls(error="invalid_request", error_description="Missing code") + + _session_info = _mngr.get_session_info_by_token( + _access_code, grant=True, handler_key="authorization_code" + ) + client_id = _session_info["client_id"] + if client_id != req["client_id"]: + logger.debug("{} owner of token".format(client_id)) + logger.warning("Client using token it was not given") + return self.error_cls(error="invalid_grant", error_description="Wrong client") + + _cinfo = self.endpoint.upstream_get("context").cdb.get(client_id) + + if ("resource_indicators" in _cinfo + and "access_token" in _cinfo["resource_indicators"]): + resource_indicators_config = _cinfo["resource_indicators"]["access_token"] + else: + resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None) + + if resource_indicators_config is not None: + if "policy" not in resource_indicators_config: + policy = {"policy": {"callable": validate_resource_indicators_policy}} + resource_indicators_config.update(policy) + + req = self._enforce_resource_indicators_policy(req, resource_indicators_config) + + if isinstance(req, TokenErrorResponse): + return req + + # if "grant_types_supported" in _context.cdb[client_id]: + # grant_types_supported = _context.cdb[client_id].get("grant_types_supported") + # else: + # grant_types_supported = _context.provider_info["grant_types_supported"] + + grant = _session_info["grant"] + + _based_on = grant.get_token(_access_code) + _supports_minting = _based_on.usage_rules.get("supports_minting", []) + + _authn_req = grant.authorization_request + + # If redirect_uri was in the initial authorization request + # verify that the one given here is the correct one. + if "redirect_uri" in _authn_req: + if req["redirect_uri"] != _authn_req["redirect_uri"]: + return self.error_cls( + error="invalid_request", error_description="redirect_uri mismatch" + ) + + logger.debug("All checks OK") + + issue_refresh = kwargs.get("issue_refresh", False) + + if resource_indicators_config is not None: + scope = req["scope"] + else: + scope = grant.scope + + _response = { + "token_type": "Bearer", + "scope": scope, + } + + if "access_token" in _supports_minting: + + resources = req.get("resource", None) + if resources: + token_args = {"resources": resources} + else: + token_args = None + + try: + token = self._mint_token( + token_class="access_token", + grant=grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=_based_on, + token_args=token_args + ) + except MintingNotAllowed as err: + logger.warning(err) + else: + _response["access_token"] = token.value + if token.expires_at: + _response["expires_in"] = token.expires_at - utc_time_sans_frac() + + if ( + issue_refresh + and "refresh_token" in _supports_minting + ): + try: + refresh_token = self._mint_token( + token_class="refresh_token", + grant=grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=_based_on, + ) + except MintingNotAllowed as err: + logger.warning(err) + else: + _response["refresh_token"] = refresh_token.value + + # since the grant content has changed. Make sure it's stored + _mngr[_session_info["branch_id"]] = grant + + _based_on.register_usage() + + return _response + + def _enforce_resource_indicators_policy(self, request, config): + _context = self.endpoint.upstream_get('context') + + policy = config["policy"] + callable = policy["callable"] + kwargs = policy.get("kwargs", {}) + + if isinstance(callable, str): + try: + fn = importer(callable) + except Exception: + raise ImproperlyConfigured(f"Error importing {callable} policy callable") + else: + fn = callable + try: + return fn(request, context=_context, **kwargs) + except Exception as e: + logger.error(f"Error while executing the {fn} policy callable: {e}") + return self.error_cls(error="server_error", error_description="Internal server error") + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """ + This is where clients come to get their access tokens + + :param request: The request + :param client_id: Client identifier + :returns: + """ + + _mngr = self.endpoint.upstream_get("context").session_manager + try: + _session_info = _mngr.get_session_info_by_token( + request["code"], grant=True, handler_key="authorization_code" + ) + except (KeyError, UnknownToken): + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant", error_description="Unknown code") + + grant = _session_info["grant"] + code = grant.get_token(request["code"]) + if not isinstance(code, AuthorizationCode): + return self.error_cls(error="invalid_request", error_description="Wrong token type") + + if code.is_active() is False: + return self.error_cls(error="invalid_request", error_description="Code inactive") + + _auth_req = grant.authorization_request + + if "client_id" not in request: # Optional for access token request + request["client_id"] = _auth_req["client_id"] + + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) + + return request diff --git a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py new file mode 100755 index 00000000..469eb8b6 --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py @@ -0,0 +1,77 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.message import Message +from idpyoidc.time_util import utc_time_sans_frac +from . import TokenEndpointHelper + +logger = logging.getLogger(__name__) + + +class ClientCredentials(TokenEndpointHelper): + + def __init__(self, endpoint, config=None): + TokenEndpointHelper.__init__(self, endpoint, config) + + def process_request(self, req: Union[Message, dict], **kwargs): + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + logger.debug("Client credentials flow") + + # verify the client and the user + + client_id = req['client_id'] + _authenticated = req.get("authenticated", False) + if not _authenticated: + if _context.cdb[client_id] != req['client_secret']: + logger.warning("Client authentication failed") + return self.error_cls(error="invalid_request", error_description="Wrong client") + + _grant_types_supported = _context.cdb[client_id].get('grant_types_supported') + if _grant_types_supported and 'client_credentials' not in _grant_types_supported: + return self.error_cls(error="invalid_request", + error_description="Unsupported grant type") + + # Is there a previous session ? + try: + _session_info = _mngr.get(['client_credentials', client_id]) + _grant = _session_info["grant"] + except KeyError: + logger.debug('No previous session') + branch_id = _mngr.add_grant(['client_credentials', client_id]) + _session_info = _mngr.get_session_info(branch_id) + + _grant = _session_info["grant"] + + token_type = "Bearer" + + _allowed = _context.cdb[client_id].get('allowed_scopes', []) + access_token = self._mint_token( + token_class="access_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=None, + scope=_allowed, + token_type=token_type, + ) + + _resp = { + "access_token": access_token.value, + "token_type": access_token.token_class, + "scope": _allowed, + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + return _resp + + def post_parse_request( + self, + request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs + ): + return request diff --git a/src/idpyoidc/server/oauth2/token_helper/refresh_token.py b/src/idpyoidc/server/oauth2/token_helper/refresh_token.py new file mode 100755 index 00000000..62341149 --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/refresh_token.py @@ -0,0 +1,147 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.message import Message +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server.session.token import RefreshToken +from idpyoidc.server.token.exception import UnknownToken +from idpyoidc.time_util import utc_time_sans_frac +from . import TokenEndpointHelper + +logger = logging.getLogger(__name__) + + +class RefreshTokenHelper(TokenEndpointHelper): + + def process_request(self, req: Union[Message, dict], **kwargs): + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + logger.debug("Refresh Token") + + if req["grant_type"] != "refresh_token": + return self.error_cls(error="invalid_request", error_description="Wrong grant_type") + + token_value = req["refresh_token"] + _session_info = _mngr.get_session_info_by_token( + token_value, grant=True, handler_key="refresh_token" + ) + logger.debug("Session info: {}".format(_session_info)) + + if _session_info["client_id"] != req["client_id"]: + logger.debug("{} owner of token".format(_session_info["client_id"])) + logger.warning("Client using token it was not given") + return self.error_cls(error="invalid_grant", error_description="Wrong client") + + _grant = _session_info["grant"] + + token_type = "Bearer" + + # Is DPOP supported + if "dpop_signing_alg_values_supported" in _context.provider_info: + _dpop_jkt = req.get("dpop_jkt") + if _dpop_jkt: + _grant.extra["dpop_jkt"] = _dpop_jkt + token_type = "DPoP" + + token = _grant.get_token(token_value) + scope = _grant.find_scope(token) + if "scope" in req: + scope = req["scope"] + access_token = self._mint_token( + token_class="access_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=token, + scope=scope, + token_type=token_type, + ) + + _resp = { + "access_token": access_token.value, + "token_type": access_token.token_type, + "scope": scope, + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + _mints = token.usage_rules.get("supports_minting") + issue_refresh = kwargs.get("issue_refresh", False) + if "refresh_token" in _mints and issue_refresh: + refresh_token = self._mint_token( + token_class="refresh_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=token, + scope=scope, + ) + refresh_token.usage_rules = token.usage_rules.copy() + _resp["refresh_token"] = refresh_token.value + + token.register_usage() + + if ( + "client_id" in req + and req["client_id"] in _context.cdb + and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] + ): + revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") + else: + revoke_refresh = self.endpoint.revoke_refresh_on_issue + + if revoke_refresh: + token.revoke() + + return _resp + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """ + This is where clients come to refresh their access tokens + + :param request: The request + :param client_id: Client identifier + :returns: + """ + + request = RefreshAccessTokenRequest(**request.to_dict()) + _context = self.endpoint.upstream_get("context") + + request.verify( + keyjar=self.endpoint.upstream_get('sttribute', 'keyjar'), + opponent_id=client_id) + + _mngr = _context.session_manager + try: + _session_info = _mngr.get_session_info_by_token( + request["refresh_token"], grant=True, handler_key="refresh_token" + ) + except (KeyError, UnknownToken): + logger.error("Refresh token invalid") + return self.error_cls(error="invalid_grant", error_description="Invalid refresh token") + + grant = _session_info["grant"] + token = grant.get_token(request["refresh_token"]) + + if not isinstance(token, RefreshToken): + return self.error_cls(error="invalid_request", error_description="Wrong token type") + + if token.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Refresh token inactive" + ) + + if "scope" in request: + req_scopes = set(request["scope"]) + scopes = set(grant.find_scope(token.based_on)) + if not req_scopes.issubset(scopes): + return self.error_cls( + error="invalid_request", + error_description="Invalid refresh scopes", + ) + + return request diff --git a/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py b/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py new file mode 100755 index 00000000..75eee741 --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py @@ -0,0 +1,107 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc.exception import FailedAuthentication +from idpyoidc.message import Message +from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.util import instantiate +from . import TokenEndpointHelper +from ...user_authn.authn_context import pick_auth + +logger = logging.getLogger(__name__) + + +class ResourceOwnerPasswordCredentials(TokenEndpointHelper): + + def __init__(self, endpoint, config=None): + TokenEndpointHelper.__init__(self, endpoint, config) + self.user_db = {} + if config: + _db = config.get('db') + if _db: + _db_kwargs = _db.get("kwargs", {}) + self.user_db = instantiate(_db["class"], **_db_kwargs) + + def process_request(self, req: Union[Message, dict], **kwargs): + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + logger.debug("Client credentials flow") + + # verify the client and the user + + client_id = req['client_id'] + _cinfo = _context.cdb.get(client_id) + if not _cinfo: + logger.error('Unknown client') + return self.error_cls(error="invalid_grant", error_description="Unknown client") + + if _cinfo['client_secret'] != req['client_secret']: + logger.warning("Client secret mismatch") + return self.error_cls(error="invalid_grant", error_description="Wrong client") + + _auth_method = None + _acr = kwargs.get('acr') + if _acr: + _auth_method = _context.authn_broker.pick(_acr) + else: + try: + _auth_method = pick_auth(_context, req) + except Exception as exc: + logger.exception(f"An error occurred while picking the authN broker: {exc}") + + if not _auth_method: + return self.error_cls(error="invalid_request", + error_description="Can't authenticate user") + + authn = _auth_method["method"] + # authn_class_ref = _auth_method["acr"] + + try: + _username = authn.verify(username=req['username'], password=req['password']) + except FailedAuthentication: + logger.warning("User password did not match") + return self.error_cls(error="invalid_grant", error_description="Wrong user") + + # Is there a previous session ? + try: + _session_info = _mngr.get([_username, client_id]) + _grant = _session_info["grant"] + except KeyError: + logger.debug('No previous session') + branch_id = _mngr.add_grant([_username, client_id]) + _session_info = _mngr.get_session_info(branch_id) + + _grant = _session_info["grant"] + + token_type = "Bearer" + + _allowed = _context.cdb[client_id].get('allowed_scopes', []) + access_token = self._mint_token( + token_class="access_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=None, + scope=_allowed, + token_type=token_type, + ) + + _resp = { + "access_token": access_token.value, + "token_type": access_token.token_class, + "scope": _allowed + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + return _resp + + def post_parse_request( + self, + request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs + ): + return request diff --git a/src/idpyoidc/server/oauth2/token_helper/token_exchange.py b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py new file mode 100755 index 00000000..119407e9 --- /dev/null +++ b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py @@ -0,0 +1,310 @@ +import logging + +from cryptojwt import BadSyntax +from cryptojwt.exception import JWKESTException + +from idpyoidc.exception import ImproperlyConfigured +from idpyoidc.exception import MissingRequiredAttribute +from idpyoidc.exception import MissingRequiredValue +from idpyoidc.message.oauth2 import TokenExchangeRequest +from idpyoidc.message.oauth2 import TokenExchangeResponse +from idpyoidc.message.oidc import TokenErrorResponse +from idpyoidc.server.constant import DEFAULT_REQUESTED_TOKEN_TYPE +from idpyoidc.server.exception import ToOld +from idpyoidc.server.exception import UnAuthorizedClientScope +from idpyoidc.server.oauth2.authorization import check_unknown_scopes_policy +from idpyoidc.server.session.token import MintingNotAllowed +from idpyoidc.server.session.token import TOKEN_TYPES_MAPPING +from idpyoidc.server.token.exception import UnknownToken +from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.util import importer +from . import TokenEndpointHelper +from . import validate_token_exchange_policy + +logger = logging.getLogger(__name__) + + +class TokenExchangeHelper(TokenEndpointHelper): + """Implements Token Exchange a.k.a. RFC8693""" + + token_types_mapping = { + "urn:ietf:params:oauth:token-type:access_token": "access_token", + "urn:ietf:params:oauth:token-type:refresh_token": "refresh_token", + } + + def __init__(self, endpoint, config=None): + TokenEndpointHelper.__init__(self, endpoint=endpoint, config=config) + if config is None: + self.config = { + "requested_token_types_supported": [ + "urn:ietf:params:oauth:token-type:access_token", + "urn:ietf:params:oauth:token-type:refresh_token", + ], + "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "policy": {"": {"callable": validate_token_exchange_policy}}, + } + else: + self.config = config + + def post_parse_request(self, request, client_id="", **kwargs): + request = TokenExchangeRequest(**request.to_dict()) + + _context = self.endpoint.upstream_get("context") + if "token_exchange" in _context.cdb[request["client_id"]]: + config = _context.cdb[request["client_id"]]["token_exchange"] + else: + config = self.config + + try: + request.verify( + keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), + opponent_id=client_id + ) + except ( + MissingRequiredAttribute, + ValueError, + MissingRequiredValue, + JWKESTException, + ) as err: + return self.endpoint.error_cls(error="invalid_request", error_description="%s" % err) + + self._validate_configuration(config) + + _mngr = _context.session_manager + try: + # token exchange is about minting one token based on another + _handler_key = self.token_types_mapping[request["subject_token_type"]] + _session_info = _mngr.get_session_info_by_token( + request["subject_token"], grant=True, handler_key=_handler_key + ) + except (KeyError, UnknownToken, BadSyntax) as err: + logger.error(f"Subject token invalid ({err}).") + return self.error_cls( + error="invalid_request", error_description="Subject token invalid" + ) + + # Find the token instance based on the token value + token = _mngr.find_token(_session_info["branch_id"], request["subject_token"]) + if token.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Subject token inactive" + ) + + resp = self._enforce_policy(request, token, config) + if isinstance(resp, TokenErrorResponse): + return resp + + scopes = resp.get("scope", []) + scopes = _context.scopes_handler.filter_scopes(scopes, client_id=resp["client_id"]) + + if not scopes: + logger.error("All requested scopes have been filtered out.") + return self.error_cls( + error="invalid_scope", error_description="Invalid requested scopes" + ) + + _requested_token_type = resp.get( + "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" + ) + _token_class = self.token_types_mapping[_requested_token_type] + if _token_class == "refresh_token" and "offline_access" not in scopes: + return TokenErrorResponse( + error="invalid_request", + error_description="Exchanging this subject token to refresh token forbidden", + ) + + return resp + + def _enforce_policy(self, request, token, config): + _context = self.endpoint.upstream_get("context") + subject_token_types_supported = config.get( + "subject_token_types_supported", self.token_types_mapping.keys() + ) + subject_token_type = request["subject_token_type"] + if subject_token_type not in subject_token_types_supported: + return TokenErrorResponse( + error="invalid_request", + error_description="Unsupported subject token type", + ) + if self.token_types_mapping[subject_token_type] != token.token_class: + return TokenErrorResponse( + error="invalid_request", + error_description="Wrong token type", + ) + + if ( + "requested_token_type" in request + and request["requested_token_type"] not in config["requested_token_types_supported"] + ): + return TokenErrorResponse( + error="invalid_request", + error_description="Unsupported requested token type", + ) + + request_info = dict(scope=request.get("scope", token.scope)) + try: + check_unknown_scopes_policy(request_info, request["client_id"], _context) + except UnAuthorizedClientScope: + return self.error_cls( + error="invalid_grant", + error_description="Unauthorized scope requested", + ) + + if subject_token_type not in config["policy"]: + subject_token_type = "" + + policy = config["policy"][subject_token_type] + callable = policy["callable"] + kwargs = policy.get("kwargs", {}) + + if isinstance(callable, str): + try: + fn = importer(callable) + except Exception: + raise ImproperlyConfigured(f"Error importing {callable} policy callable") + else: + fn = callable + + try: + return fn(request, context=_context, subject_token=token, **kwargs) + except Exception as e: + logger.error(f"Error while executing the {fn} policy callable: {e}") + return self.error_cls(error="server_error", error_description="Internal server error") + + def token_exchange_response(self, token, issued_token_type): + response_args = {} + response_args["access_token"] = token.value + response_args["scope"] = token.scope + response_args["issued_token_type"] = issued_token_type + + if token.expires_at: + response_args["expires_in"] = token.expires_at - utc_time_sans_frac() + if hasattr(token, "token_type"): + response_args["token_type"] = token.token_type + else: + response_args["token_type"] = "N_A" + + return TokenExchangeResponse(**response_args) + + def process_request(self, request, **kwargs): + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + try: + _handler_key = self.token_types_mapping[request["subject_token_type"]] + _session_info = _mngr.get_session_info_by_token( + request["subject_token"], grant=True, handler_key=_handler_key + ) + except ToOld: + logger.error("Subject token has expired.") + return self.error_cls( + error="invalid_request", error_description="Subject token has expired" + ) + except (KeyError, UnknownToken): + logger.error("Subject token invalid.") + return self.error_cls( + error="invalid_request", error_description="Subject token invalid" + ) + + grant = _session_info["grant"] + token = _mngr.find_token(_session_info["branch_id"], request["subject_token"]) + _requested_token_type = request.get( + "requested_token_type", "urn:ietf:params:oauth:token-type:access_token" + ) + + _token_class = self.token_types_mapping[_requested_token_type] + + sid = _session_info["branch_id"] + + _token_type = "Bearer" + # Is DPOP supported + if "dpop_signing_alg_values_supported" in _context.provider_info: + if request.get("dpop_jkt"): + _token_type = "DPoP" + scopes = request.get("scope", []) + + if request["client_id"] != _session_info["client_id"]: + _token_usage_rules = _context.authz.usage_rules(request["client_id"]) + + sid = _mngr.create_exchange_session( + exchange_request=request, + original_grant=grant, + original_session_id=sid, + user_id=_session_info["user_id"], + client_id=request["client_id"], + token_usage_rules=_token_usage_rules, + scopes=scopes, + ) + + try: + _session_info = _mngr.get_session_info(session_id=sid, grant=True) + except Exception: + logger.error("Error retrieving token exchange session information") + return self.error_cls( + error="server_error", error_description="Internal server error" + ) + + resources = request.get("resource") + if resources and request.get("audience"): + resources = list(set(resources + request.get("audience"))) + else: + resources = request.get("audience") + + _token_args = None + if resources: + _token_args = {"resources": resources} + + try: + new_token = self._mint_token( + token_class=_token_class, + grant=_session_info["grant"], + session_id=sid, + client_id=request["client_id"], + based_on=token, + scope=scopes, + token_args=_token_args, + token_type=_token_type, + ) + new_token.expires_at = token.expires_at + except MintingNotAllowed: + logger.error(f"Minting not allowed for {_token_class}") + return self.error_cls( + error="invalid_grant", + error_description="Token Exchange not allowed with that token", + ) + + return self.token_exchange_response(new_token, _requested_token_type) + + def _validate_configuration(self, config): + if "requested_token_types_supported" not in config: + raise ImproperlyConfigured( + "Missing 'requested_token_types_supported' from Token Exchange configuration" + ) + if "policy" not in config: + raise ImproperlyConfigured("Missing 'policy' from Token Exchange configuration") + if "" not in config["policy"]: + raise ImproperlyConfigured( + "Default Token Exchange policy configuration is not defined" + ) + if "callable" not in config["policy"][""]: + raise ImproperlyConfigured( + "Missing 'callable' from default Token Exchange policy configuration" + ) + + _default_requested_token_type = config.get("default_requested_token_type", + DEFAULT_REQUESTED_TOKEN_TYPE) + if _default_requested_token_type not in config["requested_token_types_supported"]: + raise ImproperlyConfigured( + f"Unsupported default requested_token_type {_default_requested_token_type}" + ) + + def get_handler_key(self, request, endpoint_context): + client_info = endpoint_context.cdb.get(request["client_id"], {}) + + default_requested_token_type = ( + client_info.get("token_exchange", {}).get("default_requested_token_type", None) + or + self.config.get("default_requested_token_type", DEFAULT_REQUESTED_TOKEN_TYPE) + ) + + requested_token_type = request.get("requested_token_type", default_requested_token_type) + return TOKEN_TYPES_MAPPING[requested_token_type] diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index 941010b6..50350590 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -16,7 +16,7 @@ from idpyoidc.server import Endpoint from idpyoidc.server.client_authn import ClientSecretBasic from idpyoidc.server.exception import NoSuchAuthentication -from idpyoidc.server.oidc.token_helper import AccessTokenHelper +from idpyoidc.server.oidc.token_helper.access_token import AccessTokenHelper from idpyoidc.server.session.token import MintingNotAllowed from idpyoidc.server.util import execute diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index 045ce73b..67598713 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -7,9 +7,9 @@ from idpyoidc.message.oidc import TokenErrorResponse from idpyoidc.server.oauth2 import token from idpyoidc.server.oidc.backchannel_authentication import CIBATokenHelper -from idpyoidc.server.oidc.token_helper import AccessTokenHelper -from idpyoidc.server.oidc.token_helper import RefreshTokenHelper -from idpyoidc.server.oidc.token_helper import TokenExchangeHelper +from idpyoidc.server.oidc.token_helper.access_token import AccessTokenHelper +from idpyoidc.server.oidc.token_helper.refresh_token import RefreshTokenHelper +from idpyoidc.server.oidc.token_helper.token_exchange import TokenExchangeHelper logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/server/oidc/token_helper/__init__.py b/src/idpyoidc/server/oidc/token_helper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/idpyoidc/server/oidc/token_helper.py b/src/idpyoidc/server/oidc/token_helper/access_token.py similarity index 55% rename from src/idpyoidc/server/oidc/token_helper.py rename to src/idpyoidc/server/oidc/token_helper/access_token.py index 80972205..b83d3dc4 100755 --- a/src/idpyoidc/server/oidc/token_helper.py +++ b/src/idpyoidc/server/oidc/token_helper/access_token.py @@ -2,18 +2,14 @@ from typing import Optional from typing import Union -from cryptojwt import BadSyntax from cryptojwt.jwe.exception import JWEException from cryptojwt.jws.exception import NoSuitableSigningKeys from cryptojwt.jwt import utc_time_sans_frac from idpyoidc.message import Message -from idpyoidc.message.oidc import RefreshAccessTokenRequest -from idpyoidc.server import oauth2 from idpyoidc.server.oauth2.token_helper import TokenEndpointHelper from idpyoidc.server.session.token import AuthorizationCode from idpyoidc.server.session.token import MintingNotAllowed -from idpyoidc.server.session.token import RefreshToken from idpyoidc.server.token.exception import UnknownToken from idpyoidc.util import sanitize @@ -206,171 +202,3 @@ def post_parse_request( logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) return request - - -class RefreshTokenHelper(TokenEndpointHelper): - - def process_request(self, req: Union[Message, dict], **kwargs): - _context = self.endpoint.upstream_get("context") - _mngr = _context.session_manager - - if req["grant_type"] != "refresh_token": - return self.error_cls(error="invalid_request", error_description="Wrong grant_type") - - token_value = req["refresh_token"] - - _session_info = _mngr.get_session_info_by_token( - token_value, handler_key="refresh_token", grant=True - ) - if _session_info["client_id"] != req["client_id"]: - logger.debug("{} owner of token".format(_session_info["client_id"])) - logger.warning("{} using token it was not given".format(req["client_id"])) - return self.error_cls(error="invalid_grant", error_description="Wrong client") - - _grant = _session_info["grant"] - - token_type = "Bearer" - - # Is DPOP supported - if "dpop_signing_alg_values_supported" in _context.provider_info: - _dpop_jkt = req.get("dpop_jkt") - if _dpop_jkt: - _grant.extra["dpop_jkt"] = _dpop_jkt - token_type = "DPoP" - - token = _grant.get_token(token_value) - scope = _grant.find_scope(token.based_on) - if "scope" in req: - scope = req["scope"] - access_token = self._mint_token( - token_class="access_token", - grant=_grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=token, - scope=scope, - token_type=token_type, - ) - - _resp = { - "access_token": access_token.value, - "token_type": token_type, - "scope": scope, - } - - if access_token.expires_at: - _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() - - _mints = token.usage_rules.get("supports_minting") - - issue_refresh = kwargs.get("issue_refresh", None) - # The existence of offline_access scope overwrites issue_refresh - if issue_refresh is None and "offline_access" in scope: - issue_refresh = True - - if "refresh_token" in _mints and issue_refresh: - refresh_token = self._mint_token( - token_class="refresh_token", - grant=_grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=token, - scope=scope, - ) - refresh_token.usage_rules = token.usage_rules.copy() - _resp["refresh_token"] = refresh_token.value - - if "id_token" in _mints and "openid" in scope: - try: - _idtoken = self._mint_token( - token_class="id_token", - grant=_grant, - session_id=_session_info["branch_id"], - client_id=_session_info["client_id"], - based_on=token, - scope=scope, - ) - except (JWEException, NoSuitableSigningKeys) as err: - logger.warning(str(err)) - resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", - ) - return resp - - _resp["id_token"] = _idtoken.value - - token.register_usage() - - if ( - "client_id" in req - and req["client_id"] in _context.cdb - and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] - ): - revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") - else: - revoke_refresh = self.endpoint.revoke_refresh_on_issue - - if revoke_refresh: - token.revoke() - - return _resp - - def post_parse_request( - self, - request: Union[Message, dict], - client_id: Optional[str] = "", - **kwargs - ): - """ - This is where clients come to refresh their access tokens - - :param request: The request - :param client_id: Client identifier - :returns: - """ - - request = RefreshAccessTokenRequest(**request.to_dict()) - _context = self.endpoint.upstream_get("context") - - request.verify(keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), - opponent_id=client_id) - - _mngr = _context.session_manager - try: - _session_info = _mngr.get_session_info_by_token( - request["refresh_token"], handler_key="refresh_token", grant=True - ) - except (KeyError, UnknownToken, BadSyntax): - logger.error("Refresh token invalid") - return self.error_cls(error="invalid_grant", error_description="Invalid refresh token") - - grant = _session_info["grant"] - token = grant.get_token(request["refresh_token"]) - - if not isinstance(token, RefreshToken): - return self.error_cls(error="invalid_request", error_description="Wrong token type") - - if token.is_active() is False: - return self.error_cls( - error="invalid_request", error_description="Refresh token inactive" - ) - - if "scope" in request: - req_scopes = set(request["scope"]) - scopes = set(grant.find_scope(token.based_on)) - if not req_scopes.issubset(scopes): - return self.error_cls( - error="invalid_request", - error_description="Invalid refresh scopes", - ) - - return request - - -class TokenExchangeHelper(oauth2.token_helper.TokenExchangeHelper): - token_types_mapping = { - "urn:ietf:params:oauth:token-type:access_token": "access_token", - "urn:ietf:params:oauth:token-type:refresh_token": "refresh_token", - "urn:ietf:params:oauth:token-type:id_token": "id_token", - } diff --git a/src/idpyoidc/server/oidc/token_helper/refresh_token.py b/src/idpyoidc/server/oidc/token_helper/refresh_token.py new file mode 100755 index 00000000..534109a3 --- /dev/null +++ b/src/idpyoidc/server/oidc/token_helper/refresh_token.py @@ -0,0 +1,179 @@ +import logging +from typing import Optional +from typing import Union + +from cryptojwt import BadSyntax +from cryptojwt.jwe.exception import JWEException +from cryptojwt.jws.exception import NoSuitableSigningKeys +from cryptojwt.jwt import utc_time_sans_frac + +from idpyoidc.message import Message +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server.oauth2.token_helper import TokenEndpointHelper +from idpyoidc.server.session.token import AuthorizationCode +from idpyoidc.server.session.token import MintingNotAllowed +from idpyoidc.server.session.token import RefreshToken +from idpyoidc.server.token.exception import UnknownToken +from idpyoidc.util import sanitize + +logger = logging.getLogger(__name__) + +class RefreshTokenHelper(TokenEndpointHelper): + + def process_request(self, req: Union[Message, dict], **kwargs): + _context = self.endpoint.upstream_get("context") + _mngr = _context.session_manager + + if req["grant_type"] != "refresh_token": + return self.error_cls(error="invalid_request", error_description="Wrong grant_type") + + token_value = req["refresh_token"] + + _session_info = _mngr.get_session_info_by_token( + token_value, handler_key="refresh_token", grant=True + ) + if _session_info["client_id"] != req["client_id"]: + logger.debug("{} owner of token".format(_session_info["client_id"])) + logger.warning("{} using token it was not given".format(req["client_id"])) + return self.error_cls(error="invalid_grant", error_description="Wrong client") + + _grant = _session_info["grant"] + + token_type = "Bearer" + + # Is DPOP supported + if "dpop_signing_alg_values_supported" in _context.provider_info: + _dpop_jkt = req.get("dpop_jkt") + if _dpop_jkt: + _grant.extra["dpop_jkt"] = _dpop_jkt + token_type = "DPoP" + + token = _grant.get_token(token_value) + scope = _grant.find_scope(token.based_on) + if "scope" in req: + scope = req["scope"] + access_token = self._mint_token( + token_class="access_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=token, + scope=scope, + token_type=token_type, + ) + + _resp = { + "access_token": access_token.value, + "token_type": token_type, + "scope": scope, + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + _mints = token.usage_rules.get("supports_minting") + + issue_refresh = kwargs.get("issue_refresh", None) + # The existence of offline_access scope overwrites issue_refresh + if issue_refresh is None and "offline_access" in scope: + issue_refresh = True + + if "refresh_token" in _mints and issue_refresh: + refresh_token = self._mint_token( + token_class="refresh_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=token, + scope=scope, + ) + refresh_token.usage_rules = token.usage_rules.copy() + _resp["refresh_token"] = refresh_token.value + + if "id_token" in _mints and "openid" in scope: + try: + _idtoken = self._mint_token( + token_class="id_token", + grant=_grant, + session_id=_session_info["branch_id"], + client_id=_session_info["client_id"], + based_on=token, + scope=scope, + ) + except (JWEException, NoSuitableSigningKeys) as err: + logger.warning(str(err)) + resp = self.error_cls( + error="invalid_request", + error_description="Could not sign/encrypt id_token", + ) + return resp + + _resp["id_token"] = _idtoken.value + + token.register_usage() + + if ( + "client_id" in req + and req["client_id"] in _context.cdb + and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] + ): + revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") + else: + revoke_refresh = self.endpoint.revoke_refresh_on_issue + + if revoke_refresh: + token.revoke() + + return _resp + + def post_parse_request( + self, + request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs + ): + """ + This is where clients come to refresh their access tokens + + :param request: The request + :param client_id: Client identifier + :returns: + """ + + request = RefreshAccessTokenRequest(**request.to_dict()) + _context = self.endpoint.upstream_get("context") + + request.verify(keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), + opponent_id=client_id) + + _mngr = _context.session_manager + try: + _session_info = _mngr.get_session_info_by_token( + request["refresh_token"], handler_key="refresh_token", grant=True + ) + except (KeyError, UnknownToken, BadSyntax): + logger.error("Refresh token invalid") + return self.error_cls(error="invalid_grant", error_description="Invalid refresh token") + + grant = _session_info["grant"] + token = grant.get_token(request["refresh_token"]) + + if not isinstance(token, RefreshToken): + return self.error_cls(error="invalid_request", error_description="Wrong token type") + + if token.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Refresh token inactive" + ) + + if "scope" in request: + req_scopes = set(request["scope"]) + scopes = set(grant.find_scope(token.based_on)) + if not req_scopes.issubset(scopes): + return self.error_cls( + error="invalid_request", + error_description="Invalid refresh scopes", + ) + + return request + diff --git a/src/idpyoidc/server/oidc/token_helper/token_exchange.py b/src/idpyoidc/server/oidc/token_helper/token_exchange.py new file mode 100755 index 00000000..39025a56 --- /dev/null +++ b/src/idpyoidc/server/oidc/token_helper/token_exchange.py @@ -0,0 +1,14 @@ +import logging + +from idpyoidc.server.oauth2.token_helper.token_exchange import TokenExchangeHelper as \ + OAuth2TokenExchangeHelper + +logger = logging.getLogger(__name__) + + +class TokenExchangeHelper(OAuth2TokenExchangeHelper): + token_types_mapping = { + "urn:ietf:params:oauth:token-type:access_token": "access_token", + "urn:ietf:params:oauth:token-type:refresh_token": "refresh_token", + "urn:ietf:params:oauth:token-type:id_token": "id_token", + } diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index c5a2fd47..c0307dc9 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -153,13 +153,13 @@ class UserPassJinja2(UserAuthnMethod): url_endpoint = "/verify/user_pass_jinja" def __init__( - self, - db, - template_handler, - template="user_pass.jinja2", - upstream_get=None, - verify_endpoint="", - **kwargs, + self, + db, + template_handler, + template="user_pass.jinja2", + upstream_get=None, + verify_endpoint="", + **kwargs, ): super(UserPassJinja2, self).__init__(upstream_get=upstream_get) @@ -218,7 +218,31 @@ def verify(self, *args, **kwargs): raise FailedAuthentication() +class UserPass(UserAuthnMethod): + + def __init__( + self, + db_conf, + upstream_get=None, + **kwargs, + ): + + super(UserPass, self).__init__(upstream_get=upstream_get) + self.user_db = instantiate(db_conf["class"], **db_conf["kwargs"]) + + def __call__(self, **kwargs): + pass + + def verify(self, *args, **kwargs): + username = kwargs["username"] + if username in self.user_db and self.user_db[username] == kwargs["password"]: + return username + else: + raise FailedAuthentication() + + class BasicAuthn(UserAuthnMethod): + def __init__(self, pwd, ttl=5, upstream_get=None): UserAuthnMethod.__init__(self, upstream_get=upstream_get) self.passwd = pwd @@ -249,6 +273,7 @@ def authenticated_as(self, client_id, cookie=None, authorization="", **kwargs): class SymKeyAuthn(UserAuthnMethod): + # user authentication using a token def __init__(self, ttl, symkey, upstream_get=None): @@ -283,6 +308,7 @@ def authenticated_as(self, client_id, cookie=None, authorization="", **kwargs): class NoAuthn(UserAuthnMethod): + # Just for testing allows anyone it without authentication def __init__(self, user, upstream_get=None): diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 1575a33f..b99aa4bb 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "pwcNBtEhyGiqrg0OeikHmSnTRs8_LZrc"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "zaBDhx4X98ZokBeA8X9hzoAIzIn1jpy3"}]} \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/test_05_oauth2.py b/tests/test_05_oauth2.py index c0cb90cb..fac4d7ad 100644 --- a/tests/test_05_oauth2.py +++ b/tests/test_05_oauth2.py @@ -571,16 +571,15 @@ def test_init(self): class TestROPCAccessTokenRequest(object): def test_init(self): ropc = ROPCAccessTokenRequest(grant_type="password", username="johndoe", password="A3ddj3w") - - assert ropc["grant_type"] == "password" + ropc.verify() assert ropc["username"] == "johndoe" assert ropc["password"] == "A3ddj3w" class TestCCAccessTokenRequest(object): def test_init(self): - cc = CCAccessTokenRequest(scope="/foo") - assert cc["grant_type"] == "client_credentials" + cc = CCAccessTokenRequest(scope="/foo", grant_type='client_credentials') + cc.verify() assert cc["scope"] == ["/foo"] diff --git a/tests/test_client_25_cc_oauth2_service.py b/tests/test_client_25_cc_oauth2_service.py deleted file mode 100644 index eb130b13..00000000 --- a/tests/test_client_25_cc_oauth2_service.py +++ /dev/null @@ -1,177 +0,0 @@ -import pytest - -from idpyoidc.client.entity import Entity -from idpyoidc.message.oauth2 import AccessTokenResponse -from idpyoidc.util import rndstr - -KEYDEF = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] - - -BASE_URL = "https://example.com" - -class TestRP: - @pytest.fixture(autouse=True) - def create_service(self): - client_config = { - "client_id": "client_id", - "client_secret": "another password", - "base_url": BASE_URL, - "client_authn_methods": ['client_secret_basic', 'bearer_header'] - } - services = { - "token": { - "class": "idpyoidc.client.oauth2.client_credentials.cc_access_token.CCAccessToken" - }, - "refresh_token": { - "class": "idpyoidc.client.oauth2.client_credentials.cc_refresh_access_token" - ".CCRefreshAccessToken" - }, - } - - self.entity = Entity(config=client_config, services=services) - - self.entity.get_service("accesstoken").endpoint = "https://example.com/token" - self.entity.get_service("refresh_token").endpoint = "https://example.com/token" - - def test_token_get_request(self): - request_args = {"grant_type": "client_credentials"} - _srv = self.entity.get_service("accesstoken") - _info = _srv.get_request_parameters(request_args=request_args) - assert _info["method"] == "POST" - assert _info["url"] == "https://example.com/token" - assert _info["body"] == "grant_type=client_credentials" - assert _info["headers"] == { - "Authorization": "Basic Y2xpZW50X2lkOmFub3RoZXIgcGFzc3dvcmQ=", - "Content-Type": "application/x-www-form-urlencoded", - } - - def test_token_parse_response(self): - request_args = {"grant_type": "client_credentials"} - _srv = self.entity.get_service("accesstoken") - _request_info = _srv.get_request_parameters(request_args=request_args) - - response = AccessTokenResponse( - **{ - "access_token": "2YotnFZFEjr1zCsicMWpAA", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", - "example_parameter": "example_value", - } - ) - - _response = _srv.parse_response(response.to_json(), sformat="json") - # since no state attribute is involved, a key is minted - _key = rndstr(16) - _srv.update_service_context(_response, key=_key) - info = _srv.upstream_get("context").cstate.get(_key) - assert "__expires_at" in info - - def test_refresh_token_get_request(self): - _srv = self.entity.get_service("accesstoken") - _srv.update_service_context( - { - "access_token": "2YotnFZFEjr1zCsicMWpAA", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", - "example_parameter": "example_value", - } - ) - _srv = self.entity.get_service("refresh_token") - _info = _srv.get_request_parameters(state='cc') - assert _info["method"] == "POST" - assert _info["url"] == "https://example.com/token" - assert _info["body"] == "grant_type=refresh_token" - assert _info["headers"] == { - "Authorization": "Bearer tGzv3JOkF0XG5Qx2TlKWIA", - "Content-Type": "application/x-www-form-urlencoded", - } - - def test_refresh_token_parse_response(self): - request_args = {"grant_type": "client_credentials"} - _srv = self.entity.get_service("accesstoken") - _request_info = _srv.get_request_parameters(request_args=request_args) - - response = AccessTokenResponse( - **{ - "access_token": "2YotnFZFEjr1zCsicMWpAA", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", - "example_parameter": "example_value", - } - ) - - _response = _srv.parse_response(response.to_json(), sformat="json") - # since no state attribute is involved, a key is minted - _key = rndstr(16) - _srv.update_service_context(_response, key=_key) - info = _srv.upstream_get("context").cstate.get(_key) - assert "__expires_at" in info - - # Move from token to refresh token service - - _srv = self.entity.get_service("refresh_token") - _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) - - refresh_response = AccessTokenResponse( - **{ - "access_token": "wy4R01DmMoB5xkI65nNkVv1l", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "lhNX9LSG8w1QuD6tSgc6CPfJ", - } - ) - - _response = _srv.parse_response(refresh_response.to_json(), sformat="json") - _srv.update_service_context(_response, key=_key) - info = _srv.upstream_get("context").cstate.get(_key) - assert "__expires_at" in info - - def test_2nd_refresh_token_parse_response(self): - request_args = {"grant_type": "client_credentials"} - _srv = self.entity.get_service("accesstoken") - _request_info = _srv.get_request_parameters(request_args=request_args) - - response = AccessTokenResponse( - **{ - "access_token": "2YotnFZFEjr1zCsicMWpAA", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", - "example_parameter": "example_value", - } - ) - - _response = _srv.parse_response(response.to_json(), sformat="json") - # since no state attribute is involved, a key is minted - _key = rndstr(16) - _srv.update_service_context(_response, key=_key) - info = _srv.upstream_get("context").cstate.get(_key) - assert "__expires_at" in info - - # Move from token to refresh token service - - _srv = self.entity.get_service("refresh_token") - _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) - - refresh_response = AccessTokenResponse( - **{ - "access_token": "wy4R01DmMoB5xkI65nNkVv1l", - "token_type": "example", - "expires_in": 3600, - "refresh_token": "lhNX9LSG8w1QuD6tSgc6CPfJ", - } - ) - - _response = _srv.parse_response(refresh_response.to_json(), sformat="json") - _srv.update_service_context(_response, key=_key) - info = _srv.upstream_get("context").cstate.get(_key) - assert "__expires_at" in info - - _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) - assert _request_info["headers"] == { - "Authorization": "Bearer {}".format(refresh_response["refresh_token"]), - "Content-Type": "application/x-www-form-urlencoded", - } diff --git a/tests/test_client_25_oauth2_cc_ropc.py b/tests/test_client_25_oauth2_cc_ropc.py new file mode 100644 index 00000000..f2acb7a7 --- /dev/null +++ b/tests/test_client_25_oauth2_cc_ropc.py @@ -0,0 +1,120 @@ +import pytest + +from idpyoidc.client.entity import Entity +from idpyoidc.message.oauth2 import AccessTokenResponse +from idpyoidc.util import rndstr + +KEYDEF = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] + +BASE_URL = "https://example.com" + + +class TestCC: + + @pytest.fixture(autouse=True) + def create_service(self): + client_config = { + "client_id": "client_id", + "client_secret": "another password", + "base_url": BASE_URL + } + services = { + "client_credentials": { + "class": "idpyoidc.client.oauth2.client_credentials.CCAccessTokenRequest" + } + } + + self.entity = Entity(config=client_config, services=services) + + self.entity.get_service("client_credentials").endpoint = "https://example.com/token" + + def test_token_get_request(self): + _srv = self.entity.get_service("client_credentials") + _info = _srv.get_request_parameters() + assert _info["method"] == "POST" + assert _info["url"] == "https://example.com/token" + assert _info[ + "body"] == "grant_type=client_credentials&client_id=client_id&client_secret=another+password" + + assert _info["headers"] == { + "Content-Type": "application/x-www-form-urlencoded", + } + + def test_token_parse_response(self): + _srv = self.entity.get_service("client_credentials") + _request_info = _srv.get_request_parameters() + + response = AccessTokenResponse( + **{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "access_token", + "expires_in": 3600, + "example_parameter": "example_value", + } + ) + + _response = _srv.parse_response(response.to_json(), sformat="json") + # since no state attribute is involved, a key is minted + _key = rndstr(16) + _srv.update_service_context(_response, key=_key) + info = _srv.upstream_get("context").cstate.get(_key) + assert "__expires_at" in info + + +class TestROPC: + + @pytest.fixture(autouse=True) + def create_service(self): + client_config = { + "client_id": "client_id", + "client_secret": "another password", + "base_url": BASE_URL + } + services = { + "resource_owner_password_credentials": { + "class": + "idpyoidc.client.oauth2.resource_owner_password_credentials" + ".ROPCAccessTokenRequest" + } + } + + self.entity = Entity(config=client_config, services=services) + + self.entity.get_service( + "resource_owner_password_credentials").endpoint = "https://example.com/token" + + def test_token_get_request(self): + _srv = self.entity.get_service("resource_owner_password_credentials") + _info = _srv.get_request_parameters({'username': 'diana', 'password': 'krall'}) + assert _info["method"] == "POST" + assert _info["url"] == "https://example.com/token" + assert _info["body"] == ( + "username=diana&" + "password=krall&" + "grant_type=password&" + "client_id=client_id&" + "client_secret=another+password") + + assert _info["headers"] == { + "Content-Type": "application/x-www-form-urlencoded", + } + + def test_token_parse_response(self): + _srv = self.entity.get_service("resource_owner_password_credentials") + _request_info = _srv.get_request_parameters() + + response = AccessTokenResponse( + **{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "access_token", + "expires_in": 3600, + "example_parameter": "example_value", + } + ) + + _response = _srv.parse_response(response.to_json(), sformat="json") + # since no state attribute is involved, a key is minted + _key = rndstr(16) + _srv.update_service_context(_response, key=_key) + info = _srv.upstream_get("context").cstate.get(_key) + assert "__expires_at" in info diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 04a74858..cde81f60 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -170,7 +170,7 @@ def test_parse(self): _req = self.endpoint.parse_request(CLI_REQ.to_json()) assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == set(CLI_REQ.keys()) + assert set(_req.keys()).difference(set(CLI_REQ.keys())) == {'authenticated'} def test_process_request(self): _req = self.endpoint.parse_request(CLI_REQ.to_json()) diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index ee37e852..53cccef7 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -6,17 +6,16 @@ from cryptojwt import KeyJar from cryptojwt.jws.jws import factory from cryptojwt.key_jar import build_keyjar -from idpyoidc.message import REQUIRED_LIST_OF_STRINGS - -from idpyoidc.message import SINGLE_REQUIRED_INT - -from idpyoidc.message import SINGLE_REQUIRED_STRING - -from idpyoidc.message import Message from idpyoidc.context import OidcContext from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message import Message +from idpyoidc.message import REQUIRED_LIST_OF_STRINGS +from idpyoidc.message import SINGLE_REQUIRED_INT +from idpyoidc.message import SINGLE_REQUIRED_STRING +from idpyoidc.message.oauth2 import CCAccessTokenRequest from idpyoidc.message.oauth2 import JWTAccessToken +from idpyoidc.message.oauth2 import ROPCAccessTokenRequest from idpyoidc.message.oidc import AccessTokenRequest from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import RefreshAccessTokenRequest @@ -257,7 +256,7 @@ def test_parse(self): _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) - assert set(_req.keys()) == set(_token_request.keys()) + assert set(_req.keys()).difference(set(_token_request.keys())) == {'authenticated'} def test_auth_code_grant_disallowed_per_client(self): areq = AUTH_REQ.copy() @@ -736,11 +735,12 @@ def test_do_refresh_access_token_revoked(self): def test_configure_grant_types(self): conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} - self.token_endpoint.configure_grant_types(conf) + _helper = self.token_endpoint.configure_types(conf, + self.token_endpoint.helper_by_grant_type) - assert len(self.token_endpoint.helper) == 1 - assert "access_token" in self.token_endpoint.helper - assert "refresh_token" not in self.token_endpoint.helper + assert len(_helper) == 1 + assert "access_token" in _helper + assert "refresh_token" not in _helper def test_token_request_other_client(self): _context = self.context @@ -842,6 +842,7 @@ def test_refresh_token_request_other_client(self): KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") + def upstream_get(what, *args): if what == "context": if not args: @@ -850,6 +851,7 @@ def upstream_get(what, *args): if args[0] == 'keyjar': return KEYJAR + def test_def_jwttoken(): _handler = handler.factory(upstream_get=upstream_get, **DEFAULT_TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] @@ -866,6 +868,7 @@ def test_def_jwttoken(): msg.verify() assert True + def test_jwttoken(): _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] @@ -882,6 +885,7 @@ def test_jwttoken(): msg.verify() assert True + class MyAccessToken(Message): c_param = { "iss": SINGLE_REQUIRED_STRING, @@ -892,6 +896,7 @@ class MyAccessToken(Message): 'usage': SINGLE_REQUIRED_STRING } + def test_jwttoken_2(): _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) token_handler = _handler['access_token'] @@ -906,4 +911,80 @@ def test_jwttoken_2(): msg = MyAccessToken(**_jws.jwt.payload()) # test if all required claims are there msg.verify() - assert True \ No newline at end of file + assert True + + +class TestClientCredentialsFlow(object): + + @pytest.fixture(autouse=True) + def create_endpoint(self, conf): + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + context = server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + "allowed_flows": ['client_credentials', 'resource_owner_password_credentials'] + } + self.session_manager = context.session_manager + self.token_endpoint = server.get_endpoint("token") + self.user_id = "diana" + self.context = context + + def test_client_credentials(self): + request = CCAccessTokenRequest(client_id="client_1", client_secret='hemligt', + grant_type='client_credentials', scope="whatever") + request = self.token_endpoint.parse_request(request) + response = self.token_endpoint.process_request(request) + assert set(response.keys()) == {'response_args', 'cookie', 'http_headers'} + assert set(response["response_args"].keys()) == {'access_token', 'token_type', 'scope', + 'expires_in'} + + +class TestResourceOwnerPasswordCredentialsFlow(object): + + @pytest.fixture(autouse=True) + def create_endpoint(self, conf): + conf["authentication"] = { + "user": { + "acr": "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + "class": "idpyoidc.server.user_authn.user.UserPass", + "kwargs": { + "db_conf": { + "class": "idpyoidc.server.util.JSONDictDB", + "kwargs": {"filename": "passwd.json"} + } + } + } + } + + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + context = server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + "grant_types_supported": ['client_credentials', 'resource_owner_password_credentials'], + } + self.session_manager = context.session_manager + self.token_endpoint = server.get_endpoint("token") + self.context = context + + def test_resource_owner_password_credentials(self): + request = ROPCAccessTokenRequest(client_id="client_1", + client_secret='hemligt', + grant_type='resource_owner_password_credentials', + username='diana', + password='krall', + scope="whatever") + request = self.token_endpoint.parse_request(request) + response = self.token_endpoint.process_request(request) + assert set(response.keys()) == {'response_args', 'cookie', 'http_headers'} + assert set(response["response_args"].keys()) == {'access_token', 'token_type', 'scope', + 'expires_in'} diff --git a/tests/test_server_25_oauth2_cc_ropc.py b/tests/test_server_25_oauth2_cc_ropc.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index e599c71f..547c0bc8 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -3,8 +3,8 @@ import os import pytest -from cryptojwt import JWT from cryptojwt import as_unicode +from cryptojwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.utils import as_bytes @@ -91,6 +91,7 @@ def full_path(local_file): @pytest.mark.parametrize("jwt_token", [True, False]) class TestEndpoint: + @pytest.fixture(autouse=True) def create_endpoint(self, jwt_token): conf = { @@ -207,7 +208,7 @@ def create_endpoint(self, jwt_token): "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } server.keyjar.import_jwks_as_json( - server.keyjar.export_jwks_as_json(private=True),context.issuer + server.keyjar.export_jwks_as_json(private=True), context.issuer ) self.introspection_endpoint = server.get_endpoint("introspection") self.token_endpoint = server.get_endpoint("token") @@ -265,7 +266,7 @@ def test_parse_with_client_auth_in_req(self): ) assert isinstance(_req, TokenIntrospectionRequest) - assert set(_req.keys()) == {"token", "client_id", "client_secret"} + assert set(_req.keys()) == {"token", "client_id", "client_secret", 'authenticated'} def test_parse_with_wrong_client_authn(self): access_token = self._get_access_token(AUTH_REQ) diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index 01783749..da6b19f2 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -75,6 +75,7 @@ class TestEndpoint(object): + @pytest.fixture(autouse=True) def create_endpoint(self): conf = { @@ -149,7 +150,7 @@ def test_do_response(self): "client_id={}".format(_resp["response_args"]["client_id"]), http_info=http_info, ) - assert set(_api_req.keys()) == {"client_id"} + assert set(_api_req.keys()) == {"client_id", 'authenticated'} _info = self.registration_api_endpoint.process_request(request=_api_req) assert set(_info.keys()) == {"response_args"} diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index 9ae4f83e..78ed8787 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -284,7 +284,7 @@ def test_parse(self): _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) - assert set(_req.keys()) == set(_token_request.keys()) + assert set(_req.keys()).difference(set(_token_request.keys())) == {'authenticated'} def test_process_request(self): session_id = self._create_session(AUTH_REQ) @@ -938,11 +938,12 @@ def test_do_refresh_access_token_revoked(self): def test_configure_grant_types(self): conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} - self.token_endpoint.configure_grant_types(conf) + _helper = self.token_endpoint.configure_types(conf, + self.token_endpoint.helper_by_grant_type) - assert len(self.token_endpoint.helper) == 1 - assert "access_token" in self.token_endpoint.helper - assert "refresh_token" not in self.token_endpoint.helper + assert len(_helper) == 1 + assert "access_token" in _helper + assert "refresh_token" not in _helper def test_access_token_lifetime(self): lifetime = 100 diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index de9eabc0..5c957291 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -567,7 +567,7 @@ def test_additional_parameters(self): Test that a token exchange with additional parameters including scope, audience and subject_token_type works. """ - conf = self.endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] @@ -656,7 +656,7 @@ def test_wrong_resource(self): """ Test that requesting a token for an unknown resource fails. """ - conf = self.endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] areq = AUTH_REQ.copy() @@ -726,7 +726,7 @@ def test_wrong_audience(self): """ Test that requesting a token for an unknown audience fails. """ - conf = self.endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] areq = AUTH_REQ.copy() diff --git a/tests/test_server_38_oauth2_revocation_endpoint.py b/tests/test_server_38_oauth2_revocation_endpoint.py index 3860c760..d2d69a79 100644 --- a/tests/test_server_38_oauth2_revocation_endpoint.py +++ b/tests/test_server_38_oauth2_revocation_endpoint.py @@ -284,7 +284,7 @@ def test_parse_with_client_auth_in_req(self): ) assert isinstance(_req, TokenRevocationRequest) - assert set(_req.keys()) == {"token", "client_id", "client_secret"} + assert set(_req.keys()) == {"token", "client_id", "client_secret", 'authenticated'} def test_parse_with_wrong_client_authn(self): access_token = self._get_access_token(AUTH_REQ) diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index 0664ed54..fa1a6acd 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -199,6 +199,7 @@ def test_pushed_auth_urlencoded(self): "code_challenge_method", "client_id", "code_challenge", + 'authenticated' } def test_pushed_auth_request(self): @@ -225,6 +226,7 @@ def test_pushed_auth_request(self): "code_challenge", "request", "__verified_request", + 'authenticated' } def test_pushed_auth_urlencoded_process(self): @@ -243,6 +245,7 @@ def test_pushed_auth_urlencoded_process(self): "code_challenge_method", "client_id", "code_challenge", + 'authenticated' } _resp = self.pushed_authorization_endpoint.process_request(_req) diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py index d4a9014d..21fe1db0 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -401,7 +401,7 @@ def test_additional_parameters(self): scope, audience and subject_token_type works. """ endp = self.server.get_endpoint("token") - conf = endp.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = endp.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] @@ -436,7 +436,7 @@ def test_token_exchange_fails_if_disabled(self): grant_types_supported (that are set in its helper attribute). """ endpoint = self.server.get_endpoint("token") - del endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"] + del endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"] resp, _state, _scope = self.process_setup() @@ -452,10 +452,7 @@ def test_token_exchange_fails_if_disabled(self): _te_request, _te_resp = self.do_query("token_exchange", "token", req_args, _state) assert _te_resp["error"] == "invalid_request" - assert ( - _te_resp["error_description"] - == "Unsupported grant_type: urn:ietf:params:oauth:grant-type:token-exchange" - ) + assert _te_resp["error_description"] == "Do not know how to handle this type of request" def test_wrong_resource(self): """ @@ -463,7 +460,7 @@ def test_wrong_resource(self): """ endpoint = self.server.get_endpoint("token") - conf = endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] @@ -512,7 +509,7 @@ def test_wrong_audience(self): Test that requesting a token for an unknown audience fails. """ endpoint = self.server.get_endpoint("token") - conf = endpoint.helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] From 1098be4a5434104e4e62543a5ceb96cf358432bb Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 16 Mar 2023 17:26:57 +0100 Subject: [PATCH 68/76] Refactored token endpoint helpers and added support for the two remaining flows: client credentials and resource owner password credentials. --- src/idpyoidc/claims.py | 9 ++ src/idpyoidc/client/claims/oidc.py | 2 +- src/idpyoidc/client/oauth2/access_token.py | 2 + .../client/oauth2/refresh_access_token.py | 2 + src/idpyoidc/client/oauth2/token_exchange.py | 2 + src/idpyoidc/client/oidc/access_token.py | 2 + .../client/oidc/provider_info_discovery.py | 1 + src/idpyoidc/client/service.py | 12 +- src/idpyoidc/client/service_context.py | 1 + src/idpyoidc/message/oidc/__init__.py | 2 +- src/idpyoidc/server/__init__.py | 2 + src/idpyoidc/server/claims/oauth2.py | 2 +- src/idpyoidc/server/claims/oidc.py | 2 +- src/idpyoidc/server/configure.py | 13 +- src/idpyoidc/server/endpoint.py | 2 +- src/idpyoidc/server/oauth2/authorization.py | 5 +- src/idpyoidc/server/oauth2/token.py | 25 ++- tests/private/token_jwks.json | 2 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/request123456.jwt | 2 +- tests/static/jwks.json | 2 +- tests/test_08_transform.py | 8 +- tests/test_09_work_condition.py | 2 - tests/test_client_04_service.py | 1 - tests/test_client_21_oidc_service.py | 5 +- tests/test_client_25_oauth2_cc_ropc.py | 3 +- tests/test_server_16_endpoint_context.py | 31 +--- ...st_server_23_oidc_registration_endpoint.py | 2 +- ...server_24_oauth2_authorization_endpoint.py | 2 +- tests/test_server_24_oauth2_token_endpoint.py | 6 +- tests/test_server_35_oidc_token_endpoint.py | 2 +- tests/test_server_36_oauth2_token_exchange.py | 13 +- tests/test_server_60_dpop.py | 5 - tests/test_tandem_08_oauth2_cc_ropc.py | 145 ++++++++++++++++++ tests/test_tandem_10_oauth2_token_exchange.py | 6 - 36 files changed, 230 insertions(+), 97 deletions(-) create mode 100644 tests/test_tandem_08_oauth2_cc_ropc.py diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index fa28bb17..05893a29 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -217,6 +217,15 @@ def supported(self, claim): def prefers(self): return self.prefer + def get_claim(self, key, default=None): + _val = self.get_usage(key) + if _val is None: + _val = self.get_preference(key) + + if _val is None: + return default + else: + return _val SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] diff --git a/src/idpyoidc/client/claims/oidc.py b/src/idpyoidc/client/claims/oidc.py index d68fbeec..dfad0c17 100644 --- a/src/idpyoidc/client/claims/oidc.py +++ b/src/idpyoidc/client/claims/oidc.py @@ -75,7 +75,7 @@ class Claims(client_claims.Claims): "contacts": None, "default_max_age": 86400, "encrypt_id_token_supported": None, - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + # "grant_types_supported": ["authorization_code", "refresh_token"], "logo_uri": None, "id_token_signing_alg_values_supported": claims.get_signing_algs, "id_token_encryption_alg_values_supported": claims.get_encryption_algs, diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index f2fe2a56..a100a830 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -27,6 +27,8 @@ class AccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" + _include = {"grant_types_supported": ['authorization_code']} + _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg": get_signing_algs, diff --git a/src/idpyoidc/client/oauth2/refresh_access_token.py b/src/idpyoidc/client/oauth2/refresh_access_token.py index 6dbc6d5a..69400787 100644 --- a/src/idpyoidc/client/oauth2/refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/refresh_access_token.py @@ -23,6 +23,8 @@ class RefreshAccessToken(Service): default_authn_method = "bearer_header" http_method = "POST" + _include = {"grant_types_supported": ['refresh_token']} + def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) diff --git a/src/idpyoidc/client/oauth2/token_exchange.py b/src/idpyoidc/client/oauth2/token_exchange.py index 9bed32cd..36a3658a 100644 --- a/src/idpyoidc/client/oauth2/token_exchange.py +++ b/src/idpyoidc/client/oauth2/token_exchange.py @@ -27,6 +27,8 @@ class TokenExchange(Service): request_body_type = "urlencoded" response_body_type = "json" + _include = {'grant_types_supported': ['urn:ietf:params:oauth:grant-type:token-exchange']} + def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index 9eb4329f..c39a404d 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -23,6 +23,8 @@ class AccessToken(access_token.AccessToken): error_msg = oidc.ResponseMessage default_authn_method = "client_secret_basic" + _include = {"grant_types_supported": ['authorization_code']} + _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, "token_endpoint_auth_signing_alg_values_supported": get_signing_algs diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index 12a87b0b..a05fde77 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -46,6 +46,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): error_msg = ResponseMessage service_name = "provider_info" + _include = {} _supports = {} def __init__(self, upstream_get, conf=None): diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 0fc6f52f..e17bbf49 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -1,4 +1,5 @@ """ The basic Service class upon which all the specific services are built. """ +import copy import json import logging from typing import Callable @@ -13,8 +14,8 @@ from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.util import importer from .client_auth import client_auth_setup from .client_auth import method_to_item @@ -68,6 +69,7 @@ class Service(ImpExp): init_args = ["upstream_get"] + _include = {} _supports = {} _callback_path = {} @@ -648,6 +650,14 @@ def supports(self): res[key] = val return res + def extends(self, info): + for claim, val in self._include.items(): + if claim in info: + info[claim].extend(val) + else: + info[claim] = copy.copy(val) + return info + def get_callback_path(self, callback): return self._callback_path.get(callback) diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index e0ea87d0..ae6e75d0 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -292,6 +292,7 @@ def supports(self): else: for service in services.values(): res.update(service.supports()) + res = service.extends(res) res.update(self.claims.supports()) return res diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 9f8fd295..797eaa63 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -911,7 +911,7 @@ class ProviderConfigurationResponse(ResponseMessage): "request_parameter_supported": False, "request_uri_parameter_supported": True, "require_request_uri_registration": True, - "grant_types_supported": ["authorization_code", "implicit"], + "grant_types_supported": ["authorization_code"], } def verify(self, **kwargs): diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 0a277ca8..7f3d7d94 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -78,6 +78,8 @@ def __init__( if _token_endp: _token_endp.allow_refresh = allow_refresh_token(self.context) + self.context.map_supported_to_preferred() + def get_endpoints(self, *arg): return self.endpoint diff --git a/src/idpyoidc/server/claims/oauth2.py b/src/idpyoidc/server/claims/oauth2.py index 8b259487..e0a1f418 100644 --- a/src/idpyoidc/server/claims/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -19,7 +19,7 @@ class Claims(claims.Claims): _supports = { # "client_authn_methods": get_client_authn_methods, - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], "response_types_supported": ["code"], "response_modes_supported": ["code"], "jwks_uri": None, diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index 26434eaa..5f48dcd6 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -44,7 +44,7 @@ class Claims(server_claims.Claims): "default_max_age": 86400, "display_values_supported": None, "encrypt_id_token_supported": None, - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], "id_token_signing_alg_values_supported": claims.get_signing_algs, "id_token_encryption_alg_values_supported": claims.get_encryption_algs, "id_token_encryption_enc_values_supported": claims.get_encryption_encs, diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 0daae06a..8dd8f215 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -14,14 +14,8 @@ logger = logging.getLogger(__name__) OP_DEFAULT_CONFIG = { - "capabilities": { + "preference": { "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], }, "cookie_handler": { "class": "idpyoidc.server.cookie_handler.CookieHandler", @@ -157,6 +151,7 @@ class EntityConfiguration(Base): "httpc_params": {}, "issuer": "", "key_conf": None, + 'preference': {}, "session_params": None, "template_dir": None, "token_handler_args": {}, @@ -344,11 +339,11 @@ def __init__( }, } }, - "capabilities": { + "preference": { "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ "authorization_code", - "implicit", + # "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", ], diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 4ffd8bff..a0763ceb 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -227,7 +227,7 @@ def parse_request( req["client_id"] = auth_info["client_id"] _auth_method = auth_info.get('method') - if _auth_method and _auth_method != 'public': + if _auth_method and _auth_method not in ['public', 'none']: req['authenticated'] = True _client_id = auth_info["client_id"] diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index df2ca321..66e6989e 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -268,7 +268,7 @@ def authn_args_gather( def check_unknown_scopes_policy(request_info, client_id, context): - if not context.conf["capabilities"].get("deny_unknown_scopes"): + if not context.claims.get_preference("deny_unknown_scopes"): return scope = request_info["scope"] @@ -347,8 +347,9 @@ class Authorization(Endpoint): "request_object_signing_alg_values_supported": claims.get_signing_algs, "request_object_encryption_alg_values_supported": claims.get_encryption_algs, "request_object_encryption_enc_values_supported": claims.get_encryption_encs, - "grant_types_supported": ["authorization_code", "implicit"], + # "grant_types_supported": ["authorization_code", "implicit"], "scopes_supported": [], + "deny_unknown_scopes": False } default_capabilities = { "client_authn_method": ["request_param", "public"], diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 93081309..98bc9fa8 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) + class Token(Endpoint): request_cls = Message response_cls = AccessTokenResponse @@ -33,14 +34,19 @@ class Token(Endpoint): endpoint_name = "token_endpoint" name = "token" default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} + token_exchange_helper = TokenExchangeHelper + helper_by_grant_type = { "authorization_code": AccessTokenHelper, "refresh_token": RefreshTokenHelper, "urn:ietf:params:oauth:grant-type:token-exchange": TokenExchangeHelper, "client_credentials": ClientCredentials, - "resource_owner_password_credentials": ResourceOwnerPasswordCredentials, + "password": ResourceOwnerPasswordCredentials, + } + + _supports = { + "grant_types_supported": list(helper_by_grant_type.keys()) } - token_exchange_helper = TokenExchangeHelper def __init__(self, upstream_get, new_refresh_token=False, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) @@ -49,8 +55,8 @@ def __init__(self, upstream_get, new_refresh_token=False, **kwargs): self.new_refresh_token = new_refresh_token self.grant_type_helper = self.configure_types(kwargs.get("grant_types_helpers"), self.helper_by_grant_type) - self.grant_types_supported = kwargs.get("grant_types_supported", - list(self.grant_type_helper.keys())) + # self.grant_types_supported = kwargs.get("grant_types_supported", + # list(self.grant_type_helper.keys())) self.revoke_refresh_on_issue = kwargs.get("revoke_refresh_on_issue", False) self.resource_indicators_config = kwargs.get('resource_indicators', None) @@ -95,9 +101,11 @@ def _get_helper(self, _client_id = client_id or request.get('client_id') if client_id: client = self.upstream_get('context').cdb[client_id] - grant_types_supported = client.get("grant_types_supported", - self.grant_types_supported) - if grant_type not in grant_types_supported: + _grant_types_supported = client.get("grant_types_supported", + self.upstream_get('context').claims.get_claim( + "grant_types_supported", []) + ) + if grant_type not in _grant_types_supported: return self.error_cls( error="invalid_request", error_description=f"Unsupported grant_type: {grant_type}", @@ -181,3 +189,6 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar if _cookie: resp["cookie"] = [_cookie] return resp + + def supports(self): + return {'grant_types_supported': list(self.grant_type_helper.keys())} diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index b99aa4bb..9feee908 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "zaBDhx4X98ZokBeA8X9hzoAIzIn1jpy3"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "31Lm9Fi6Lt3NBm2djOvfV3k-j7kFVORm"}]} \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d5ce25ed..84a27042 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 77081f40..9b062907 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 28f863e1..1d5c9d1d 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJJTVFraVkxckVhT2pncW5VZkpGSjN6dGV1MG9QMDJ2S1J5d0xyM0p1aHFjIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjc1NjczNjU1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.Oj4q4UDeBTbkpI3oAGl6Bwt_DS1_rHJQxmpLwkKQTEgaTh08Fhr64iZoxUyyJYOZGkmMlgXz5nJZLt1uO5uotsA2wZaoAn6-EMXZ8lfm8vDxq5YdqoJX_8UfE3HSQDlmIsuHdtjOSYijYUP2FtSMutryzxkAW9Sp50GpaJ6QmpL_GE55lEfpHpR4A_2rf0SikwW2xHtMYU90XI0Jv_m-6rBf6sTlaqePge6ToNjCxpyYDOWKMa-qwrMeFe99JECzDdMbMYXQB2WPmRVuFkV7mJOoxFY7wviqjT_-eM3YI_jKDPB6M2-oVXQ7IjUv5t3WqYWasEoGEXgkMk-WcfBtTw \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJBWGV0Wm1SVXFWT2NPX0NTMFZrNF9oM05vRjlJRHpzYUEwZHBWRFpZVS1BIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjc4OTU2Mzg1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.axJ7C32rBbu5jWwnZAa04_3QSPwytuRtUjRTOpcHnSa1D_XsnPjVuVmRbYWFPepcaPeMN6GYuOn22_6quVSRktnMvVPfh-C1YttosfWOYavq60H3Hav3mLa357gGgCSRJJG1RGXQlSf5PU7P1hdiJoCaiejpVaA7efkBcQagTndlxFoE3oRoeKr9RqLKPRvRnlB-qv6FpanLwm4gY4NnAOjHo_1BOP6tvJTfad6aQwW5sRL-NaKLLrfkHgKnsTpyEUrBtl6-63O8_w9ckBsT1B9JBH1T6vhkjY-vGBptTnrAf_0giDi_Lw7jZMrETqJjnyMlQIDd88AOlnHV0IDvew \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 08588bd2..ac4d9aff 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -62,7 +62,6 @@ def test_supported(self): 'frontchannel_logout_session_required', 'frontchannel_logout_supported', 'frontchannel_logout_uri', - 'grant_types_supported', 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', @@ -109,6 +108,7 @@ def test_oidc_setup(self): 'error', 'error_description', 'error_uri', + 'grant_types_supported', 'issuer', 'op_policy_uri', 'op_tos_uri', @@ -159,7 +159,6 @@ def test_oidc_setup(self): 'default_max_age', 'encrypt_request_object_supported', 'encrypt_userinfo_supported', - 'grant_types_supported', 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', @@ -184,7 +183,7 @@ def test_oidc_setup(self): reg_claim.append(key) assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == { - 'post_logout_redirect_uri'} + 'post_logout_redirect_uri', 'grant_types'} # Which ones are list -> singletons @@ -236,7 +235,6 @@ def test_provider_info(self): 'default_max_age', 'encrypt_request_object_supported', 'encrypt_userinfo_supported', - 'grant_types_supported', 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', @@ -340,7 +338,6 @@ def test_registration_response(self): 'client_name', 'contacts', 'default_max_age', - 'grant_types', 'id_token_signed_response_alg', 'logo_uri', 'redirect_uris', @@ -382,7 +379,6 @@ def test_registration_response(self): 'default_max_age', 'encrypt_request_object_supported', 'encrypt_userinfo_supported', - 'grant_types', 'id_token_signed_response_alg', 'jwks_uri', 'logo_uri', diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index b6f41230..6ebeb5e3 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -167,7 +167,6 @@ def test_registration_response(self): 'client_name', 'contacts', 'default_max_age', - 'grant_types', 'id_token_signed_response_alg', 'jwks', 'logo_uri', @@ -212,7 +211,6 @@ def test_registration_response(self): 'default_max_age', 'encrypt_request_object_supported', 'encrypt_userinfo_supported', - 'grant_types', 'id_token_signed_response_alg', 'jwks', 'jwks_uri', diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index 2873dd5c..d0ded3a6 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -59,7 +59,6 @@ def test_use(self): 'client_id', 'default_max_age', 'encrypt_request_object_supported', - 'grant_types', 'id_token_signed_response_alg', 'jwks', 'redirect_uris', diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 729314a3..fb3ac1b2 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -552,7 +552,6 @@ def test_post_parse(self): # "require_request_uri_registration": True, "grant_types_supported": [ "authorization_code", - "token", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", ], @@ -754,7 +753,7 @@ def test_post_parse(self): 'encrypt_id_token_supported': False, 'encrypt_request_object_supported': False, 'encrypt_userinfo_supported': False, - 'grant_types': ['authorization_code', 'refresh_token'], + 'grant_types': ['authorization_code'], 'id_token_signed_response_alg': 'RS256', 'post_logout_redirect_uris': ['https://rp.example.com/post'], 'redirect_uris': ['https://example.com/cli/authz_cb'], @@ -815,7 +814,7 @@ def test_post_parse_2(self): 'encrypt_id_token_supported': False, 'encrypt_request_object_supported': False, 'encrypt_userinfo_supported': False, - 'grant_types': ['authorization_code', 'implicit', 'refresh_token'], + 'grant_types': ['authorization_code'], 'id_token_signed_response_alg': 'RS256', 'post_logout_redirect_uris': ['https://rp.example.com/post'], 'redirect_uris': ['https://example.com/cli/authz_cb'], diff --git a/tests/test_client_25_oauth2_cc_ropc.py b/tests/test_client_25_oauth2_cc_ropc.py index f2acb7a7..e03382fb 100644 --- a/tests/test_client_25_oauth2_cc_ropc.py +++ b/tests/test_client_25_oauth2_cc_ropc.py @@ -1,6 +1,7 @@ import pytest from idpyoidc.client.entity import Entity +from idpyoidc.client.oauth2 import Client from idpyoidc.message.oauth2 import AccessTokenResponse from idpyoidc.util import rndstr @@ -24,7 +25,7 @@ def create_service(self): } } - self.entity = Entity(config=client_config, services=services) + self.entity = Client(config=client_config, services=services) self.entity.get_service("client_credentials").endpoint = "https://example.com/token" diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index 751018c3..38d3dc2c 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -46,13 +46,8 @@ class Endpoint_1(Endpoint): "client_secret_basic", ], "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], "endpoint": { + "userinfo": { "path": "userinfo", "class": Endpoint_1, @@ -103,7 +98,6 @@ def create_endpoint_context(self): def test(self): self.context.set_provider_info() assert set(self.context.provider_info.keys()) == { - 'grant_types_supported', 'id_token_signing_alg_values_supported', 'issuer', 'jwks_uri', @@ -112,28 +106,6 @@ def test(self): 'userinfo_signing_alg_values_supported', 'version'} - def test_allow_refresh_token(self): - assert allow_refresh_token(self.context) - - # Have the software but is not expected to use it. - self.context.set_preference("grant_types_supported", [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - ]) - assert allow_refresh_token(self.context) is False - - # Don't have the software but are expected to use it. - self.context.set_preference("grant_types_supported", [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ]) - del self.context.session_manager.token_handler.handler["refresh_token"] - with pytest.raises(OidcEndpointError): - assert allow_refresh_token(self.context) is False - class Tokenish(Endpoint): _supports = { @@ -200,7 +172,6 @@ def test_provider_configuration(kwargs): server.context.set_provider_info() pi = server.context.provider_info assert set(pi.keys()) == {'acr_values_supported', - 'grant_types_supported', 'id_token_signing_alg_values_supported', 'issuer', 'jwks_uri', diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index cde81f60..002f1073 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -170,7 +170,7 @@ def test_parse(self): _req = self.endpoint.parse_request(CLI_REQ.to_json()) assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()).difference(set(CLI_REQ.keys())) == {'authenticated'} + assert set(_req.keys()).difference(set(CLI_REQ.keys())) == set() def test_process_request(self): _req = self.endpoint.parse_request(CLI_REQ.to_json()) diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 29685a01..893e2194 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -579,7 +579,7 @@ def test_setup_auth_invalid_scope(self): kaka = _context.cookie_handler.make_cookie_content("value", "sso") # force to 400 Http Error message if the release scope policy is heavy! - _context.conf["capabilities"]["deny_unknown_scopes"] = True + _context.claims.set_preference("deny_unknown_scopes", True) excp = None try: res = self.endpoint.process_request(request, http_info={"headers": {"cookie": [kaka]}}) diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 53cccef7..262f7331 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -927,7 +927,7 @@ def create_endpoint(self, conf): "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], - "allowed_flows": ['client_credentials', 'resource_owner_password_credentials'] + "grant_types_supported": ['client_credentials', 'password'] } self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") @@ -970,7 +970,7 @@ def create_endpoint(self, conf): "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], - "grant_types_supported": ['client_credentials', 'resource_owner_password_credentials'], + "grant_types_supported": ['client_credentials', 'password'], } self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") @@ -979,7 +979,7 @@ def create_endpoint(self, conf): def test_resource_owner_password_credentials(self): request = ROPCAccessTokenRequest(client_id="client_1", client_secret='hemligt', - grant_type='resource_owner_password_credentials', + grant_type='password', username='diana', password='krall', scope="whatever") diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index 78ed8787..c6141261 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -102,7 +102,7 @@ def conf(): return { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, - "capabilities": CAPABILITIES, + "preference": CAPABILITIES, "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, "token_handler_args": { "jwks_file": "private/token_jwks.json", diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 5c957291..169aebe1 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -45,12 +45,6 @@ CAPABILITIES = { "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], } AUTH_REQ = AuthorizationRequest( @@ -91,7 +85,7 @@ def create_endpoint(self): conf = { "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, - "capabilities": CAPABILITIES, + "preference": CAPABILITIES, "cookie_handler": { "class": CookieHandler, "kwargs": {"keys": {"key_defs": COOKIE_KEYDEFS}}, @@ -180,6 +174,9 @@ def create_endpoint(self): } server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.context = server.context + # Necessary to get grant_types_supported into preferred + self.context.map_supported_to_preferred() + self.context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], @@ -187,7 +184,6 @@ def create_endpoint(self): "token_endpoint_auth_method": "client_secret_post", "grant_types_supported": [ "authorization_code", - "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", "urn:ietf:params:oauth:grant-type:token-exchange" @@ -279,6 +275,7 @@ def test_token_exchange1(self, token): {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzI6aGVtbGlndA==")}}, ) _resp = self.endpoint.process_request(request=_req) + print(_resp['response_args']) assert set(_resp["response_args"].keys()) == { "access_token", "token_type", diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index c93eb150..69eef704 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -85,11 +85,6 @@ def test_verify_header(): ], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, diff --git a/tests/test_tandem_08_oauth2_cc_ropc.py b/tests/test_tandem_08_oauth2_cc_ropc.py new file mode 100644 index 00000000..30c2a967 --- /dev/null +++ b/tests/test_tandem_08_oauth2_cc_ropc.py @@ -0,0 +1,145 @@ +import os + +from idpyoidc.client.oauth2 import Client + +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.oauth2.token import Token +from idpyoidc.server.user_info import UserInfo + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] +CRYPT_CONFIG = { + "kwargs": { + "keys": { + "key_defs": [ + {"type": "OCT", "use": ["enc"], "kid": "password"}, + {"type": "OCT", "use": ["enc"], "kid": "salt"}, + ] + }, + "iterations": 1, + } +} + +SESSION_PARAMS = {"encrypter": CRYPT_CONFIG} + +CONFIG = { + "issuer": "https://example.net/", + "httpc_params": {"verify": False}, + "preference": { + "grant_types_supported": ["client_credentials", "password"] + }, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS, 'read_only': False}, + "token_handler_args": { + "jwks_defs": {"key_defs": KEYDEFS}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + } + } + }, + "endpoint": { + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": ["client_secret_basic", "client_secret_post"], + # "grant_types_supported": ['client_credentials', 'password'] + }, + }, + }, + "client_authn": verify_client, + "claims_interface": { + "class": "idpyoidc.server.session.claims.OAuth2ClaimsInterface", + "kwargs": {}, + }, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "expires_in": 300, + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": {"expires_in": 600}, + "refresh_token": { + "expires_in": 86400, + "supports_minting": ["access_token", "refresh_token"], + }, + }, + "expires_in": 43200, + } + }, + }, + "session_params": {"encrypter": SESSION_PARAMS}, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "authentication": { + "user": { + "acr": "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + "class": "idpyoidc.server.user_authn.user.UserPass", + "kwargs": { + "db_conf": { + "class": "idpyoidc.server.util.JSONDictDB", + "kwargs": {"filename": full_path("passwd.json")} + } + } + } + } +} + +CLIENT_BASE_URL = "https://example.com" + +CLIENT_CONFIG = { + "client_id": "client_1", + "client_secret": "another password", + "base_url": CLIENT_BASE_URL +} +CLIENT_SERVICES = { + "resource_owner_password_credentials": { + "class": "idpyoidc.client.oauth2.resource_owner_password_credentials.ROPCAccessTokenRequest" + } +} + + +def test_ropc(): + # Client side + + client = Client(config=CLIENT_CONFIG, services=CLIENT_SERVICES) + client.get_service("resource_owner_password_credentials").endpoint = "https://example.com/token" + + service = client.get_service('resource_owner_password_credentials') + client_request_info = service.get_request_parameters( + request_args={'username': 'diana', 'password': 'krall'}) + + # Server side + + server = Server(ASConfiguration(conf=CONFIG, base_path=BASEDIR), cwd=BASEDIR) + server.context.cdb["client_1"] = { + "client_secret": "another password", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "code id_token", "id_token"], + "allowed_scopes": ["resourceA"], + # "grant_types_supported": ['client_credentials', 'password'] + } + + token_endpoint = server.get_endpoint("token") + request = token_endpoint.parse_request(client_request_info['request']) + assert request diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py index 21fe1db0..98b1a172 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -94,12 +94,6 @@ def create_endpoint(self): "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, "subject_types_supported": ["public", "pairwise", "ephemeral"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], "client_authn_method": [ "client_secret_basic", "client_secret_post", From a678cf0675a0bc8a0c681a4ec0d3e5fb3be3f898 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 17 Mar 2023 09:28:51 +0100 Subject: [PATCH 69/76] Removed the old client_credentials implementationen. Moved display_values_supported to client.oauth2/oidc.claims to copy what ctriant did. --- .../oauth2/Xclient_credentials/__init__.py | 0 .../Xclient_credentials/cc_access_token.py | 27 ----------- .../cc_refresh_access_token.py | 48 ------------------- src/idpyoidc/server/claims/oauth2.py | 3 +- src/idpyoidc/server/claims/oidc.py | 1 + src/idpyoidc/server/oauth2/authorization.py | 3 +- tests/private/token_jwks.json | 2 +- ...server_24_oauth2_authorization_endpoint.py | 2 +- 8 files changed, 5 insertions(+), 81 deletions(-) delete mode 100644 src/idpyoidc/client/oauth2/Xclient_credentials/__init__.py delete mode 100644 src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py delete mode 100644 src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py diff --git a/src/idpyoidc/client/oauth2/Xclient_credentials/__init__.py b/src/idpyoidc/client/oauth2/Xclient_credentials/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py b/src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py deleted file mode 100644 index af65573a..00000000 --- a/src/idpyoidc/client/oauth2/Xclient_credentials/cc_access_token.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Optional - -from idpyoidc.client.service import Service -from idpyoidc.message import oauth2 -from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.time_util import time_sans_frac - - -class CCAccessToken(Service): - msg_type = oauth2.CCAccessTokenRequest - response_cls = oauth2.AccessTokenResponse - error_msg = ResponseMessage - endpoint_name = "token_endpoint" - synchronous = True - service_name = "accesstoken" - default_authn_method = "client_secret_basic" - http_method = "POST" - request_body_type = "urlencoded" - response_body_type = "json" - - def __init__(self, upstream_get, conf=None): - Service.__init__(self, upstream_get, conf=conf) - - def update_service_context(self, resp, key: Optional[str] = "cc", **kwargs): - if "expires_in" in resp: - resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py b/src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py deleted file mode 100644 index 6ab144fc..00000000 --- a/src/idpyoidc/client/oauth2/Xclient_credentials/cc_refresh_access_token.py +++ /dev/null @@ -1,48 +0,0 @@ -from idpyoidc.client.service import Service -from idpyoidc.message import oauth2 -from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.time_util import time_sans_frac - - -class CCRefreshAccessToken(Service): - msg_type = oauth2.RefreshAccessTokenRequest - response_cls = oauth2.AccessTokenResponse - error_msg = ResponseMessage - endpoint_name = "token_endpoint" - synchronous = True - service_name = "refresh_token" - default_authn_method = "bearer_header" - http_method = "POST" - - def __init__(self, upstream_get, conf=None): - Service.__init__(self, upstream_get, conf=conf) - self.pre_construct.append(self.cc_pre_construct) - self.post_construct.append(self.cc_post_construct) - - def cc_pre_construct(self, request_args=None, **kwargs): - _state_id = kwargs.get("state", "cc") - parameters = ["refresh_token"] - _current = self.upstream_get("context").cstate - _args = _current.get_set(_state_id, claim=parameters) - - if request_args is None: - request_args = _args - else: - _args.update(request_args) - request_args = _args - - return request_args, {} - - def cc_post_construct(self, request_args, **kwargs): - for attr in ["client_id", "client_secret"]: - try: - del request_args[attr] - except KeyError: - pass - - return request_args - - def update_service_context(self, resp, key="cc", **kwargs): - if "expires_in" in resp: - resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) - self.upstream_get("context").cstate.update(key, resp) diff --git a/src/idpyoidc/server/claims/oauth2.py b/src/idpyoidc/server/claims/oauth2.py index e0a1f418..6b322baa 100644 --- a/src/idpyoidc/server/claims/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -18,8 +18,7 @@ class Claims(claims.Claims): register2preferred = REGISTER2PREFERRED _supports = { - # "client_authn_methods": get_client_authn_methods, - # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], + "deny_unknown_scopes": False, "response_types_supported": ["code"], "response_modes_supported": ["code"], "jwks_uri": None, diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index 5f48dcd6..70e68768 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -42,6 +42,7 @@ class Claims(server_claims.Claims): # "client_authn_methods": get_client_authn_methods, "contacts": None, "default_max_age": 86400, + "deny_unknown_scopes": False, "display_values_supported": None, "encrypt_id_token_supported": None, # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 66e6989e..a9973770 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -268,7 +268,7 @@ def authn_args_gather( def check_unknown_scopes_policy(request_info, client_id, context): - if not context.claims.get_preference("deny_unknown_scopes"): + if not context.get_preference("deny_unknown_scopes"): return scope = request_info["scope"] @@ -349,7 +349,6 @@ class Authorization(Endpoint): "request_object_encryption_enc_values_supported": claims.get_encryption_encs, # "grant_types_supported": ["authorization_code", "implicit"], "scopes_supported": [], - "deny_unknown_scopes": False } default_capabilities = { "client_authn_method": ["request_param", "public"], diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 9feee908..38cdb616 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "31Lm9Fi6Lt3NBm2djOvfV3k-j7kFVORm"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "nSZ0kdDYyJn4d0Oy67Z1okgykXRhCcKk"}]} \ No newline at end of file diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 893e2194..e5f0a74d 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -579,7 +579,7 @@ def test_setup_auth_invalid_scope(self): kaka = _context.cookie_handler.make_cookie_content("value", "sso") # force to 400 Http Error message if the release scope policy is heavy! - _context.claims.set_preference("deny_unknown_scopes", True) + _context.set_preference("deny_unknown_scopes", True) excp = None try: res = self.endpoint.process_request(request, http_info={"headers": {"cookie": [kaka]}}) From b9fb7bae5318a85a79722528092090fb1cb69e22 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 20 Mar 2023 16:47:00 +0200 Subject: [PATCH 70/76] Fix scopes_handler after fedservice refactor Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/server/authz/__init__.py | 4 +--- src/idpyoidc/server/claims/oauth2.py | 1 + src/idpyoidc/server/claims/oidc.py | 1 + src/idpyoidc/server/configure.py | 4 ++++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index f90094f6..8fdcb268 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -88,9 +88,7 @@ def __call__( if not scopes: scopes = request.get("scope", []) else: - _allowed = _context.cdb[_client_id].get('allowed_scopes', []) - if _allowed: - scopes = list(set(scopes).intersection(set(_allowed))) + scopes = _context.scopes_handler.filter_scopes(scopes, client_id=_client_id) grant.scope = scopes # After this is where user consent should be handled diff --git a/src/idpyoidc/server/claims/oauth2.py b/src/idpyoidc/server/claims/oauth2.py index 6b322baa..f0137543 100644 --- a/src/idpyoidc/server/claims/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -19,6 +19,7 @@ class Claims(claims.Claims): _supports = { "deny_unknown_scopes": False, + "scopes_handler": None, "response_types_supported": ["code"], "response_modes_supported": ["code"], "jwks_uri": None, diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index 70e68768..f2b57506 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -43,6 +43,7 @@ class Claims(server_claims.Claims): "contacts": None, "default_max_age": 86400, "deny_unknown_scopes": False, + "scopes_handler": None, "display_values_supported": None, "encrypt_id_token_supported": None, # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 8dd8f215..3ba7449d 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -156,6 +156,7 @@ class EntityConfiguration(Base): "template_dir": None, "token_handler_args": {}, "userinfo": None, + "scopes_handler": None } def __init__( @@ -348,6 +349,9 @@ def __init__( "refresh_token", ], }, + "scopes_handler": { + "class": "idpyoidc.server.scopes.Scopes" + }, "claims_interface": {"class": "idpyoidc.server.session.claims.ClaimsInterface", "kwargs": {}}, "cookie_handler": { "class": "idpyoidc.server.cookie_handler.CookieHandler", From 26ae4fe534d563e06b0251b85b029b7be561d541 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Sun, 2 Apr 2023 22:51:26 +0300 Subject: [PATCH 71/76] Fix refresh grant on access token helper after fedservice Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/server/oidc/token_helper/access_token.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/idpyoidc/server/oidc/token_helper/access_token.py b/src/idpyoidc/server/oidc/token_helper/access_token.py index b83d3dc4..bad2873b 100755 --- a/src/idpyoidc/server/oidc/token_helper/access_token.py +++ b/src/idpyoidc/server/oidc/token_helper/access_token.py @@ -119,7 +119,6 @@ def process_request(self, req: Union[Message, dict], **kwargs): if ( issue_refresh and "refresh_token" in _supports_minting - and "refresh_token" in grant_types_supported ): try: refresh_token = self._mint_token( From b809758068bba4f91703406f7ace782cf93275f4 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 3 Apr 2023 10:40:34 +0300 Subject: [PATCH 72/76] Fix typo on introspection endpoint Signed-off-by: Kostis Triantafyllakis --- tests/test_server_31_oauth2_introspection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index 3d707917..ab5e6985 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -498,8 +498,7 @@ def test_wrong_aud(self): auth_req = AUTH_REQ.copy() auth_req["client_id"] = "client_2" access_token = self._get_access_token(auth_req) - - _context = self.introspection_endpoint.server_get("endpoint_context") + _context = self.introspection_endpoint.upstream_get("endpoint_context") _req = self.introspection_endpoint.parse_request( { From bd9c9b96710386b52fe55e932ebe87f3329279f4 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Thu, 9 Feb 2023 12:06:20 +0200 Subject: [PATCH 73/76] Support code_challenge_methods_supported Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/message/oauth2/__init__.py | 1 + src/idpyoidc/message/oidc/__init__.py | 3 ++- src/idpyoidc/server/oauth2/authorization.py | 3 ++- src/idpyoidc/server/oidc/add_on/pkce.py | 14 ++++++++++-- src/idpyoidc/server/oidc/authorization.py | 25 ++++++++++++--------- tests/test_08_transform.py | 3 ++- tests/test_server_33_oauth2_pkce.py | 18 --------------- 7 files changed, 33 insertions(+), 34 deletions(-) mode change 100755 => 100644 src/idpyoidc/server/oidc/authorization.py diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index 723526cf..e0841847 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -352,6 +352,7 @@ class ASConfigurationResponse(Message): "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, "op_policy_uri": SINGLE_OPTIONAL_STRING, "op_tos_uri": SINGLE_OPTIONAL_STRING, + "code_challenge_methods_supported": OPTIONAL_LIST_OF_STRINGS, "revocation_endpoint": SINGLE_OPTIONAL_STRING, "introspection_endpoint": SINGLE_OPTIONAL_STRING, } diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 797eaa63..98af8d66 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -898,7 +898,8 @@ class ProviderConfigurationResponse(ResponseMessage): "frontchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, "frontchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN, "backchannel_logout_supported": SINGLE_OPTIONAL_BOOLEAN, - "backchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN + "backchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN, + "code_challenge_methods_supported": OPTIONAL_LIST_OF_STRINGS, # "jwk_encryption_url": SINGLE_OPTIONAL_STRING, # "x509_url": SINGLE_REQUIRED_STRING, # "x509_encryption_url": SINGLE_OPTIONAL_STRING, diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index a9973770..6850ca60 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -348,6 +348,7 @@ class Authorization(Endpoint): "request_object_encryption_alg_values_supported": claims.get_encryption_algs, "request_object_encryption_enc_values_supported": claims.get_encryption_encs, # "grant_types_supported": ["authorization_code", "implicit"], + "code_challenge_methods_supported": ["S256"], "scopes_supported": [], } default_capabilities = { @@ -1087,7 +1088,7 @@ def process_request( :return: dictionary """ - if isinstance(request, self.error_cls): + if "error" in request: return request _cid = request["client_id"] diff --git a/src/idpyoidc/server/oidc/add_on/pkce.py b/src/idpyoidc/server/oidc/add_on/pkce.py index 01952ecb..296cc4fa 100644 --- a/src/idpyoidc/server/oidc/add_on/pkce.py +++ b/src/idpyoidc/server/oidc/add_on/pkce.py @@ -51,7 +51,7 @@ def post_authn_parse(request, client_id, context, **kwargs): ) if "code_challenge_method" not in request: - request["code_challenge_method"] = "plain" + request["code_challenge_method"] = "S256" if "code_challenge" in request and ( request["code_challenge_method"] @@ -140,7 +140,17 @@ def add_pkce_support(endpoint: Dict[str, Endpoint], **kwargs): token_endpoint.post_parse_request.append(post_token_parse) code_challenge_methods = kwargs.get("code_challenge_methods", CC_METHOD.keys()) - + code_challenge_methods = list( + set(code_challenge_methods).intersection( + authn_endpoint._supports["code_challenge_methods_supported"] + ) + ) + if not code_challenge_methods: + raise ValueError( + "Unsupported method: {}".format( + ", ".join(kwargs.get("code_challenge_methods", CC_METHOD.keys())) + ) + ) kwargs["code_challenge_methods"] = {} for method in code_challenge_methods: if method not in CC_METHOD: diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py old mode 100755 new mode 100644 index eb6d06b7..ac14a754 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -77,18 +77,21 @@ class Authorization(authorization.Authorization): name = "authorization" _supports = { - "claims_parameter_supported": True, - "encrypt_request_object_supported": False, - "request_object_signing_alg_values_supported": claims.get_signing_algs, - "request_object_encryption_alg_values_supported": claims.get_encryption_algs, - "request_object_encryption_enc_values_supported": claims.get_encryption_encs, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, - "require_request_uri_registration": False, - "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', + **authorization.Authorization._supports, + **{ + "claims_parameter_supported": True, + "encrypt_request_object_supported": False, + "request_object_signing_alg_values_supported": claims.get_signing_algs, + "request_object_encryption_alg_values_supported": claims.get_encryption_algs, + "request_object_encryption_enc_values_supported": claims.get_encryption_encs, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "require_request_uri_registration": False, + "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', 'code id_token', 'code id_token token'], - "response_modes_supported": ['query', 'fragment', 'form_post'], - "subject_types_supported": ["public", "pairwise", "ephemeral"], + "response_modes_supported": ['query', 'fragment', 'form_post'], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + }, } def __init__(self, upstream_get: Callable, **kwargs): diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index ac4d9aff..52020451 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -117,7 +117,8 @@ def test_oidc_setup(self): 'service_documentation', 'token_endpoint', 'ui_locales_supported', - 'userinfo_endpoint'} + 'userinfo_endpoint', + 'code_challenge_methods_supported'} # parameters that are not mapped against what the OP's provider info says assert set(self.supported).difference( diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index 137668c5..fbb40d9d 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -367,24 +367,6 @@ def test_unknown_code_challenge_method(self): _authn_req["code_challenge_method"] ) - def test_unsupported_code_challenge_method(self, conf): - conf["add_on"]["pkce"]["kwargs"]["code_challenge_methods"] = ["plain"] - server = create_server(conf) - authn_endpoint = server.get_endpoint("authorization") - - _cc_info = _code_challenge() - _authn_req = AUTH_REQ.copy() - _authn_req["code_challenge"] = _cc_info["code_challenge"] - _authn_req["code_challenge_method"] = _cc_info["code_challenge_method"] - - _pr_resp = authn_endpoint.parse_request(_authn_req.to_dict()) - - assert isinstance(_pr_resp, AuthorizationErrorResponse) - assert _pr_resp["error"] == "invalid_request" - assert _pr_resp["error_description"] == "Unsupported code_challenge_method={}".format( - _authn_req["code_challenge_method"] - ) - def test_wrong_code_verifier(self): _cc_info = _code_challenge() _authn_req = AUTH_REQ.copy() From 48ecfc9a58355a2a60d32d2f0740baa79eec677d Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 21 Apr 2023 13:17:02 +0200 Subject: [PATCH 74/76] Replace the name callable with function. --- src/idpyoidc/server/oauth2/authorization.py | 14 +++++++------- .../server/oauth2/token_helper/access_token.py | 14 +++++++------- .../oauth2/token_helper/token_exchange.py | 18 +++++++++--------- src/idpyoidc/server/oauth2/token_revocation.py | 14 +++++++------- tests/private/token_jwks.json | 2 +- tests/pub_client.jwks | 2 +- tests/pub_iss.jwks | 2 +- tests/static/jwks.json | 2 +- tests/test_server_00a_client_configure.py | 6 +++--- ...est_server_24_oauth2_resource_indicators.py | 6 +++--- tests/test_server_36_oauth2_token_exchange.py | 18 +++++++++--------- ...est_server_38_oauth2_revocation_endpoint.py | 8 ++++---- tests/test_tandem_10_oauth2_token_exchange.py | 2 +- 13 files changed, 54 insertions(+), 54 deletions(-) diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 6850ca60..f6f60f99 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -526,7 +526,7 @@ def _post_parse_request(self, request, client_id, context, **kwargs): if resource_indicators_config is not None: if "policy" not in resource_indicators_config: - policy = {"policy": {"callable": validate_resource_indicators_policy}} + policy = {"policy": {"function": validate_resource_indicators_policy}} resource_indicators_config.update(policy) request = self._enforce_resource_indicators_policy(request, resource_indicators_config) @@ -536,7 +536,7 @@ def _enforce_resource_indicators_policy(self, request, config): _context = self.upstream_get("context") policy = config["policy"] - callable = policy["callable"] + function = policy["function"] kwargs = policy.get("kwargs", {}) if kwargs.get("resource_servers_per_client", None) is None: @@ -544,17 +544,17 @@ def _enforce_resource_indicators_policy(self, request, config): request["client_id"]: request["client_id"] } - if isinstance(callable, str): + if isinstance(function, str): try: - fn = importer(callable) + fn = importer(function) except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") + raise ImproperlyConfigured(f"Error importing {function} policy function") else: - fn = callable + fn = function try: return fn(request, context=_context, **kwargs) except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") + logger.error(f"Error while executing the {fn} policy function: {e}") return self.error_cls(error="server_error", error_description="Internal server error") def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): diff --git a/src/idpyoidc/server/oauth2/token_helper/access_token.py b/src/idpyoidc/server/oauth2/token_helper/access_token.py index b7b917fe..96e64c1c 100755 --- a/src/idpyoidc/server/oauth2/token_helper/access_token.py +++ b/src/idpyoidc/server/oauth2/token_helper/access_token.py @@ -58,7 +58,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): if resource_indicators_config is not None: if "policy" not in resource_indicators_config: - policy = {"policy": {"callable": validate_resource_indicators_policy}} + policy = {"policy": {"function": validate_resource_indicators_policy}} resource_indicators_config.update(policy) req = self._enforce_resource_indicators_policy(req, resource_indicators_config) @@ -152,20 +152,20 @@ def _enforce_resource_indicators_policy(self, request, config): _context = self.endpoint.upstream_get('context') policy = config["policy"] - callable = policy["callable"] + function = policy["function"] kwargs = policy.get("kwargs", {}) - if isinstance(callable, str): + if isinstance(function, str): try: - fn = importer(callable) + fn = importer(function) except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") + raise ImproperlyConfigured(f"Error importing {function} policy function") else: - fn = callable + fn = function try: return fn(request, context=_context, **kwargs) except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") + logger.error(f"Error while executing the {fn} policy function: {e}") return self.error_cls(error="server_error", error_description="Internal server error") def post_parse_request( diff --git a/src/idpyoidc/server/oauth2/token_helper/token_exchange.py b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py index 119407e9..0b5a0524 100755 --- a/src/idpyoidc/server/oauth2/token_helper/token_exchange.py +++ b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py @@ -41,7 +41,7 @@ def __init__(self, endpoint, config=None): "urn:ietf:params:oauth:token-type:refresh_token", ], "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", - "policy": {"": {"callable": validate_token_exchange_policy}}, + "policy": {"": {"function": validate_token_exchange_policy}}, } else: self.config = config @@ -154,21 +154,21 @@ def _enforce_policy(self, request, token, config): subject_token_type = "" policy = config["policy"][subject_token_type] - callable = policy["callable"] + function = policy["function"] kwargs = policy.get("kwargs", {}) - if isinstance(callable, str): + if isinstance(function, str): try: - fn = importer(callable) + fn = importer(function) except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") + raise ImproperlyConfigured(f"Error importing {function} policy function") else: - fn = callable + fn = function try: return fn(request, context=_context, subject_token=token, **kwargs) except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") + logger.error(f"Error while executing the {fn} policy function: {e}") return self.error_cls(error="server_error", error_description="Internal server error") def token_exchange_response(self, token, issued_token_type): @@ -285,9 +285,9 @@ def _validate_configuration(self, config): raise ImproperlyConfigured( "Default Token Exchange policy configuration is not defined" ) - if "callable" not in config["policy"][""]: + if "function" not in config["policy"][""]: raise ImproperlyConfigured( - "Missing 'callable' from default Token Exchange policy configuration" + "Missing 'function' from default Token Exchange policy configuration" ) _default_requested_token_type = config.get("default_requested_token_type", diff --git a/src/idpyoidc/server/oauth2/token_revocation.py b/src/idpyoidc/server/oauth2/token_revocation.py index 8ac45f49..7db5e184 100644 --- a/src/idpyoidc/server/oauth2/token_revocation.py +++ b/src/idpyoidc/server/oauth2/token_revocation.py @@ -86,7 +86,7 @@ def process_request(self, request=None, **kwargs): self.policy = _context.cdb[client_id]["token_revocation"]["policy"] except Exception: self.policy = self.token_revocation_kwargs.get("policy", { - "": {"callable": validate_token_revocation_policy}}) + "": {"function": validate_token_revocation_policy}}) if _token.token_class not in self.token_types_supported: desc = ( @@ -108,21 +108,21 @@ def _revoke(self, request, session_info): _cls = "" temp_policy = self.policy[_cls] - callable = temp_policy["callable"] + function = temp_policy["function"] kwargs = temp_policy.get("kwargs", {}) - if isinstance(callable, str): + if isinstance(function, str): try: - fn = importer(callable) + fn = importer(function) except Exception: - raise ImproperlyConfigured(f"Error importing {callable} policy callable") + raise ImproperlyConfigured(f"Error importing {function} policy function") else: - fn = callable + fn = function try: return fn(_token, session_info=session_info, **kwargs) except Exception as e: - logger.error(f"Error while executing the {fn} policy callable: {e}") + logger.error(f"Error while executing the {fn} policy function: {e}") return self.error_cls(error="server_error", error_description="Internal server error") diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 38cdb616..9e18a977 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "nSZ0kdDYyJn4d0Oy67Z1okgykXRhCcKk"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "XeeoaV1P5eINXBFEDU2U_YBXqsjJE0uD"}]} \ No newline at end of file diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index 84a27042..d5ce25ed 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "e": "AQAB", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "azZQQ2FEQjh3QnVZWVdrbHJkMEZSaWR6aVJ0LTBjeUFfeWRlbTRrRFZ5VQ", "crv": "P-256", "x": "2ADe18caWWGp6hpRbfa9HqQHDFNpid9xUmR56Wzm_wc", "y": "HnD_8QBanz4Y-UF8mKQFZXfqkGkXUSm34mLsdDKtSyk"}, {"kty": "RSA", "use": "sig", "kid": "SHEyYWcwNVk0LTdROTZzZ2FUWndIVXdack0xWUM5SEpwcS03dVUxWU4zRQ", "n": "rRz52ddyP9Y2ezSlRsnkt-sjXfV_Ii7vOFX-cStLE3IUlVeSJGEe_kAASLr2r3BE2unjntaxj67NP8D95h_rzG1SpCklTEn-aTe3FOwNyTzUH_oiDVeRoEcf04Y43ciRGYRB5PhI6ii-2lYuig6hyUr776Qxiu6-0zw-M_ay2MgGSy5CEj55dDSvcUyxStUObxGpPWnEvybO1vnE7iJEWGNe0L5uPe5nLidOiR-JwjxSWEx1xZYtIjxaf2Ulu-qu4hwgwBUQdx4bNZyBfljKj55skWuHqPMG3xMjnedQC6Ms5bR3rIkbBpvmgI3kJK-4CZikM6ruyLo94-Lk19aYQw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 161a407b..8322d976 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_server_00a_client_configure.py b/tests/test_server_00a_client_configure.py index e618e1e9..2a61b6d3 100644 --- a/tests/test_server_00a_client_configure.py +++ b/tests/test_server_00a_client_configure.py @@ -34,14 +34,14 @@ ], "policy": { "urn:ietf:params:oauth:token-type:access_token": { - "callable": "/path/to/callable", + "function": "/path/to/function", "kwargs": {"audience": ["https://example.com"], "scopes": ["openid"]}, }, "urn:ietf:params:oauth:token-type:refresh_token": { - "callable": "/path/to/callable", + "function": "/path/to/function", "kwargs": {"resource": ["https://example.com"], "scopes": ["openid"]}, }, - "": {"callable": "/path/to/callable", "kwargs": {"scopes": ["openid"]}}, + "": {"function": "/path/to/function", "kwargs": {"scopes": ["openid"]}}, }, }, }, diff --git a/tests/test_server_24_oauth2_resource_indicators.py b/tests/test_server_24_oauth2_resource_indicators.py index 57848cbc..b991bfd2 100644 --- a/tests/test_server_24_oauth2_resource_indicators.py +++ b/tests/test_server_24_oauth2_resource_indicators.py @@ -328,7 +328,7 @@ def get_cookie_value(cookie=None, name=None): "request_uri_parameter_supported": True, "resource_indicators": { "policy": { - "callable": validate_authorization_resource_indicators_policy, + "function": validate_authorization_resource_indicators_policy, "kwargs": { "resource_servers_per_client": { "client_1": ["client_1", "client_2"], @@ -350,7 +350,7 @@ def get_cookie_value(cookie=None, name=None): ], "resource_indicators": { "policy": { - "callable": validate_token_resource_indicators_policy, + "function": validate_token_resource_indicators_policy, "kwargs": { "resource_servers_per_client": { "client_1": ["client_2", "client_3"] @@ -551,7 +551,7 @@ def test_authorization_code_req_per_client(self, create_endpoint_ri_disabled): endpoint_context.cdb["client_1"]["resource_indicators"] = { "authorization_code": { "policy": { - "callable": validate_authorization_resource_indicators_policy, + "function": validate_authorization_resource_indicators_policy, "kwargs": { "resource_servers_per_client":["client_3"] }, diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index 169aebe1..e1cf6615 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -352,7 +352,7 @@ def test_token_exchange_per_client(self, token): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": {"scope": ["openid", "offline_access"]}, } }, @@ -410,7 +410,7 @@ def test_token_exchange_scopes_per_client(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["openid", "profile", "offline_access"] }, @@ -468,7 +468,7 @@ def test_token_exchange_unsupported_scopes_per_client(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["openid", "profile", "offline_access"] }, @@ -522,7 +522,7 @@ def test_token_exchange_no_scopes_requested(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["openid", "offline_access"] }, @@ -1041,7 +1041,7 @@ def test_token_exchange_unsupported_scope_requested_1(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["offline_access", "profile"] }, @@ -1130,7 +1130,7 @@ def test_token_exchange_unsupported_scope_requested_2(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["profile"] }, @@ -1218,7 +1218,7 @@ def test_token_exchange_unsupported_scope_requested_3(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["offline_access", "profile"] }, @@ -1326,7 +1326,7 @@ def test_token_exchange_unsupported_scope_requested_4(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["offline_access", "profile"] }, @@ -1424,7 +1424,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): "default_requested_token_type": "urn:ietf:params:oauth:token-type:access_token", "policy": { "": { - "callable": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": { "scope": ["profile"] }, diff --git a/tests/test_server_38_oauth2_revocation_endpoint.py b/tests/test_server_38_oauth2_revocation_endpoint.py index d2d69a79..73a0b199 100644 --- a/tests/test_server_38_oauth2_revocation_endpoint.py +++ b/tests/test_server_38_oauth2_revocation_endpoint.py @@ -386,10 +386,10 @@ def custom_token_revocation_policy(token, session_info, **kwargs): ], "policy": { "": { - "callable": validate_token_revocation_policy, + "function": validate_token_revocation_policy, }, "access_token": { - "callable": custom_token_revocation_policy, + "function": custom_token_revocation_policy, } }, } @@ -423,10 +423,10 @@ def custom_token_revocation_policy(token, session_info, **kwargs): ], "policy": { "": { - "callable": validate_token_revocation_policy, + "function": validate_token_revocation_policy, }, "refresh_token": { - "callable": custom_token_revocation_policy, + "function": custom_token_revocation_policy, } }, } diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_10_oauth2_token_exchange.py index 98b1a172..773fb218 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_10_oauth2_token_exchange.py @@ -356,7 +356,7 @@ def test_token_exchange_per_client(self, token): ], "policy": { "": { - "callable": + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": {"scope": ["openid", "offline_access"]}, } From 981dc407edfe5e6749c8496bcedc92c6d590f8e4 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 30 Jan 2023 11:55:40 +0200 Subject: [PATCH 75/76] Introduce userinfo policy Signed-off-by: Kostis Triantafyllakis --- doc/server/contents/conf.rst | 21 ++++++- src/idpyoidc/server/oidc/userinfo.py | 55 ++++++++++++++++--- .../test_server_26_oidc_userinfo_endpoint.py | 51 +++++++++++++++++ 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/doc/server/contents/conf.rst b/doc/server/contents/conf.rst index 4138094a..e5c51f05 100644 --- a/doc/server/contents/conf.rst +++ b/doc/server/contents/conf.rst @@ -408,7 +408,11 @@ An example:: "normal", "aggregated", "distributed" - ] + ], + "policy": { + "function": "/path/to/callable", + "kwargs": {} + } } }, "revocation": { @@ -747,6 +751,10 @@ the following:: "userinfo": { "class": "oidc_provider.users.UserInfo", "kwargs": { + "policy": { + "function": "/path/to/callable", + "kwargs": {} + }, "claims_map": { "phone_number": "telephone", "family_name": "last_name", @@ -760,6 +768,17 @@ the following:: } } +The policy for userinfo endpoint is optional and can also be configured in a client's metadata, for example:: + + "userinfo": { + "kwargs": { + "policy": { + "function": "/path/to/callable", + "kwargs": {} + } + } + } + ================================ Special Configuration directives ================================ diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 0beed342..58ffb107 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -10,11 +10,13 @@ from cryptojwt.jwt import utc_time_sans_frac from idpyoidc import claims +from idpyoidc.util import importer from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.exception import ClientAuthenticationError +from idpyoidc.exception import ImproperlyConfigured from idpyoidc.server.util import OAUTH2_NOCACHE_HEADERS logger = logging.getLogger(__name__) @@ -46,18 +48,28 @@ def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") - def get_client_id_from_token(self, context, token, request=None): - _info = context.session_manager.get_session_info_by_token( + if kwargs is None: + self.config = { + "policy": { + "function": "/path/to/callable", + "kwargs": {} + }, + } + else: + self.config = kwargs + + def get_client_id_from_token(self, endpoint_context, token, request=None): + _info = endpoint_context.session_manager.get_session_info_by_token( token, handler_key="access_token" ) return _info["client_id"] def do_response( - self, - response_args: Optional[Union[Message, dict]] = None, - request: Optional[Union[Message, dict]] = None, - client_id: Optional[str] = "", - **kwargs + self, + response_args: Optional[Union[Message, dict]] = None, + request: Optional[Union[Message, dict]] = None, + client_id: Optional[str] = "", + **kwargs ) -> dict: if "error" in kwargs and kwargs["error"]: @@ -157,6 +169,12 @@ def process_request(self, request=None, **kwargs): info["sub"] = _grant.sub if _grant.add_acr_value("userinfo"): info["acr"] = _grant.authentication_event["authn_info"] + + if "userinfo" in _cntxt.cdb[request["client_id"]]: + self.config["policy"] = _cntxt.cdb[request["client_id"]]["userinfo"]["policy"] + + if "policy" in self.config: + info = self._enforce_policy(request, info, token, self.config) else: info = { "error": "invalid_request", @@ -190,3 +208,26 @@ def parse_request(self, request, http_info=None, **kwargs): request["access_token"] = auth_info["token"] return request + + def _enforce_policy(self, request, response_info, token, config): + policy = config["policy"] + callable = policy["function"] + kwargs = policy.get("kwargs", {}) + + if isinstance(callable, str): + try: + fn = importer(callable) + except Exception: + raise ImproperlyConfigured(f"Error importing {callable} policy callable") + else: + fn = callable + + try: + return fn(request, token, response_info, **kwargs) + except Exception as e: + logger.error(f"Error while executing the {fn} policy callable: {e}") + return self.error_cls(error="server_error", error_description="Internal server error") + + +def validate_userinfo_policy(request, token, response_info, **kwargs): + return response_info diff --git a/tests/test_server_26_oidc_userinfo_endpoint.py b/tests/test_server_26_oidc_userinfo_endpoint.py index b349da37..bef8921b 100755 --- a/tests/test_server_26_oidc_userinfo_endpoint.py +++ b/tests/test_server_26_oidc_userinfo_endpoint.py @@ -21,10 +21,12 @@ from idpyoidc.server.scopes import SCOPE2CLAIMS from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.user_info import UserInfo +from idpyoidc.server.oidc.userinfo import validate_userinfo_policy from idpyoidc.time_util import utc_time_sans_frac from tests import CRYPT_CONFIG from tests import SESSION_PARAMS + KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -637,3 +639,52 @@ def test_process_request_absent_userinfo_conf(self): with pytest.raises(ImproperlyConfigured): code = self._mint_code(grant, session_id) + + def test_userinfo_policy(self): + _auth_req = AUTH_REQ.copy() + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + access_token = self._mint_token("access_token", grant, session_id) + + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} + + def _custom_validate_userinfo_policy(request, token, response_info, **kwargs): + return {"custom": "policy"} + + self.endpoint.config["policy"] = {} + self.endpoint.config["policy"]["function"] = _custom_validate_userinfo_policy + + _req = self.endpoint.parse_request({}, http_info=http_info) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + _response = json.loads(res["response"]) + assert "custom" in _response + + def test_userinfo_policy_per_client(self): + _auth_req = AUTH_REQ.copy() + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + access_token = self._mint_token("access_token", grant, session_id) + + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} + + def _custom_validate_userinfo_policy(request, token, response_info, **kwargs): + return {"custom": "policy"} + + self.context.cdb["client_1"]["userinfo"] = { + "policy": { + "function": _custom_validate_userinfo_policy, + "kwargs": {} + } + } + + _req = self.endpoint.parse_request({}, http_info=http_info) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + _response = json.loads(res["response"]) + assert "custom" in _response + From 24ca0a4fc9ddff13fd9950da9bc85f4a67c3c364 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Tue, 2 May 2023 10:16:55 +0300 Subject: [PATCH 76/76] Ignore PKCE for client credentials grant Signed-off-by: Kostis Triantafyllakis --- src/idpyoidc/server/oauth2/token_helper/client_credentials.py | 4 ++++ src/idpyoidc/server/oidc/add_on/pkce.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py index 469eb8b6..2c37ba93 100755 --- a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py +++ b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py @@ -3,7 +3,9 @@ from typing import Union from idpyoidc.message import Message +from idpyoidc.message.oauth2 import CCAccessTokenRequest from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.util import sanitize from . import TokenEndpointHelper logger = logging.getLogger(__name__) @@ -74,4 +76,6 @@ def post_parse_request( client_id: Optional[str] = "", **kwargs ): + request = CCAccessTokenRequest(**request.to_dict()) + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) return request diff --git a/src/idpyoidc/server/oidc/add_on/pkce.py b/src/idpyoidc/server/oidc/add_on/pkce.py index 296cc4fa..ccd8d506 100644 --- a/src/idpyoidc/server/oidc/add_on/pkce.py +++ b/src/idpyoidc/server/oidc/add_on/pkce.py @@ -7,6 +7,7 @@ from idpyoidc.message.oauth2 import AuthorizationErrorResponse from idpyoidc.message.oauth2 import RefreshAccessTokenRequest from idpyoidc.message.oauth2 import TokenExchangeRequest +from idpyoidc.message.oauth2 import CCAccessTokenRequest from idpyoidc.message.oidc import TokenErrorResponse from idpyoidc.server.endpoint import Endpoint @@ -93,7 +94,7 @@ def post_token_parse(request, client_id, context, **kwargs): """ if isinstance( request, - (AuthorizationErrorResponse, RefreshAccessTokenRequest, TokenExchangeRequest), + (AuthorizationErrorResponse, RefreshAccessTokenRequest, TokenExchangeRequest, CCAccessTokenRequest), ): return request