diff --git a/demo/oauth2_add_on_dpop.py b/demo/oauth2_add_on_dpop.py index c4f935d1..e2980d0f 100755 --- a/demo/oauth2_add_on_dpop.py +++ b/demo/oauth2_add_on_dpop.py @@ -5,7 +5,7 @@ from common import KEYDEFS from common import full_path from flow import Flow -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_signing_algs from idpyoidc.client.oauth2 import Client from idpyoidc.server import Server from idpyoidc.server.configure import ASConfiguration diff --git a/private/xmetadata/oidc.py b/private/xmetadata/oidc.py deleted file mode 100644 index 9f8db994..00000000 --- a/private/xmetadata/oidc.py +++ /dev/null @@ -1,127 +0,0 @@ -import logging -import os -from typing import Optional - -from idpyoidc import metadata -from idpyoidc.client import metadata as client_metadata -from idpyoidc.message.oidc import APPLICATION_TYPE_WEB -from idpyoidc.message.oidc import RegistrationRequest -from idpyoidc.message.oidc import RegistrationResponse - -logger = logging.getLogger(__name__) - -REGISTER2PREFERRED = { - # "require_signed_request_object": "request_object_algs_supported", - "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": "request_object_encryption_enc_values_supported", - "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", - "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", - "default_acr_values": "acr_values_supported", - "subject_type": "subject_types_supported", - "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "response_types": "response_types_supported", - "grant_types": "grant_types_supported", - # In OAuth2 but not in OIDC - "scope": "scopes_supported", - "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", - # "display": "display_values_supported", - # "claims": "claims_supported", - # "request": "request_parameter_supported", - # "request_uri": "request_uri_parameter_supported", - # 'claims_locales': 'claims_locales_supported', - # 'ui_locales': 'ui_locales_supported', -} - -PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) - -REQUEST2REGISTER = { - 'client_id': "client_id", - "client_secret": "client_secret", - # 'acr_values': "default_acr_values" , - # 'max_age': "default_max_age", - 'redirect_uri': "redirect_uris", - 'response_type': "response_types", - 'request_uri': "request_uris", - 'grant_type': "grant_types", - "scope": 'scopes_supported', - 'post_logout_redirect_uri': "post_logout_redirect_uris" -} - - -class Metadata(client_metadata.Metadata): - parameter = client_metadata.Metadata.parameter.copy() - parameter.update({ - "requests_dir": None - }) - - register2preferred = REGISTER2PREFERRED - registration_response = RegistrationResponse - registration_request = RegistrationRequest - - _supports = { - "acr_values_supported": None, - "application_type": APPLICATION_TYPE_WEB, - "callback_uris": None, - # "client_authn_methods": get_client_authn_methods, - "client_id": None, - "client_name": None, - "client_secret": None, - "client_uri": None, - "contacts": None, - "default_max_age": 86400, - "encrypt_id_token_supported": None, - "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "logo_uri": None, - "id_token_signing_alg_values_supported": metadata.get_signing_algs, - "id_token_encryption_alg_values_supported": metadata.get_encryption_algs, - "id_token_encryption_enc_values_supported": metadata.get_encryption_encs, - "initiate_login_uri": None, - "jwks": None, - "jwks_uri": None, - "policy_uri": None, - "requests_dir": None, - "require_auth_time": None, - "sector_identifier_uri": None, - "scopes_supported": ["openid"], - "subject_types_supported": ["public", "pairwise", "ephemeral"], - "tos_uri": None, - } - - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None - ): - client_metadata.Metadata.__init__(self, prefer=prefer, callback_path=callback_path) - - def verify_rules(self): - if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported"): - raise ValueError( - "You have to chose one of 'request_parameter_supported' and " - "'request_uri_parameter_supported'. You can't have both.") - - if not self.get_preference('encrypt_userinfo_supported'): - self.set_preference('userinfo_encryption_alg_values_supported', []) - self.set_preference('userinfo_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_request_object_supported'): - self.set_preference('request_object_encryption_alg_values_supported', []) - self.set_preference('request_object_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_id_token_supported'): - self.set_preference('id_token_encryption_alg_values_supported', []) - self.set_preference('id_token_encryption_enc_values_supported', []) - - def locals(self, info): - requests_dir = info.get("requests_dir") - if requests_dir: - # make sure the path exists. If not, then create it. - if not os.path.isdir(requests_dir): - os.makedirs(requests_dir) - - self.set("requests_dir", requests_dir) diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 76d83a3e..834b77ae 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,5 +1,5 @@ __author__ = "Roland Hedberg" -__version__ = "4.3.0" +__version__ = "5.0.0" VERIFIED_CLAIM_PREFIX = "__verified" @@ -10,7 +10,7 @@ def verified_claim_name(claim): def proper_path(path): """ - Clean up the path specification so it looks like something I could use. + Clean up the path specification such that it looks like something I could use. "./" "/" """ if path.startswith("./"): diff --git a/src/idpyoidc/alg_info.py b/src/idpyoidc/alg_info.py new file mode 100644 index 00000000..f3a96416 --- /dev/null +++ b/src/idpyoidc/alg_info.py @@ -0,0 +1,67 @@ +from functools import cmp_to_key +import logging + +from cryptojwt.jwe import DEPRECATED +from cryptojwt.jwe import SUPPORTED +from cryptojwt.jws.jws import SIGNER_ALGS + +logger = logging.getLogger(__name__) + +SIGNING_ALGORITHM_SORT_ORDER = ["RS", "ES", "PS", "HS", "Ed"] + + +def cmp(a, b): + return (a > b) - (a < b) + + +def alg_cmp(a, b): + if a == "none": + return 1 + elif b == "none": + return -1 + + _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) + _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) + if _pos1 == _pos2: + return (a > b) - (a < b) + elif _pos1 > _pos2: + return 1 + else: + return -1 + + +def get_signing_algs(): + # Assumes Cryptojwt + _algs = [name for name in list(SIGNER_ALGS.keys()) if name != "none" and name not in DEPRECATED["alg"]] + return sorted(_algs, key=cmp_to_key(alg_cmp)) + + +def get_encryption_algs(): + return SUPPORTED["alg"] + + +def get_encryption_encs(): + return SUPPORTED["enc"] + + +def array_or_singleton(claim_spec, values): + if isinstance(claim_spec[0], list): + if isinstance(values, list): + return values + else: + return [values] + else: + if isinstance(values, list): + return values[0] + else: # singleton + return values + + +def is_subset(a, b): + if isinstance(a, list): + if isinstance(b, list): + return set(b).issubset(set(a)) + elif isinstance(b, list): + return a in b + else: + return a == b diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index e684624f..afa6680f 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -1,4 +1,6 @@ +import logging from typing import Callable +from typing import List from typing import Optional from cryptojwt import KeyJar @@ -7,9 +9,14 @@ from idpyoidc.client.util import get_uri from idpyoidc.impexp import ImpExp +from idpyoidc.key_import import import_jwks +from idpyoidc.key_import import store_under_other_id +from idpyoidc.message import Message +from idpyoidc.transform import preferred_to_registered from idpyoidc.util import add_path from idpyoidc.util import qualified_name +logger = logging.getLogger(__name__) def claims_dump(info, exclude_attributes): return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} @@ -85,7 +92,17 @@ def construct_redirect_uris(self, base_url: str, hex: str, callbacks: Optional[d self.callback = callbacks def verify_rules(self, supports): - return True + if self.get_preference("encrypt_userinfo_supported", False) is True: + self.set_preference("userinfo_encryption_alg_values_supported", []) + self.set_preference("userinfo_encryption_enc_values_supported", []) + + if self.get_preference("encrypt_request_object_supported", False) is True: + self.set_preference("request_object_encryption_alg_values_supported", []) + self.set_preference("request_object_encryption_enc_values_supported", []) + + if self.get_preference("encrypt_id_token_supported", False) is True: + self.set_preference("id_token_encryption_alg_values_supported", []) + self.set_preference("id_token_encryption_enc_values_supported", []) def locals(self, info): pass @@ -104,11 +121,11 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): else: _keyjar = KeyJar() if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") + _keyjar = import_jwks(_keyjar, conf["jwks"], "") if "" in _keyjar and entity_id: # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) + _keyjar = store_under_other_id(_keyjar, "", entity_id, True) _httpc_params = conf.get("httpc_params") if _httpc_params: @@ -122,7 +139,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): return keyjar, _uri_path - def get_base_url(self, configuration: dict, entity_id: Optional[str]=""): + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): raise NotImplementedError() def get_id(self, configuration: dict): @@ -138,6 +155,7 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None, entity_id: Optional[str] = ""): + logger.debug(f"configuration: {configuration}") _jwks = _jwks_uri = None _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) @@ -180,6 +198,10 @@ def load_conf( elif val: self.set_preference(key, val) + for attr, val in supports.items(): + if attr not in self.prefer and val is not None: + self.set_preference(attr, val) + self.verify_rules(supports) return keyjar @@ -195,15 +217,21 @@ def set(self, key, val): def construct_uris(self, *args): pass - def supports(self): + def _expand(self, dictionary): res = {} - for key, val in self._supports.items(): + for key, val in dictionary.items(): if isinstance(val, Callable): res[key] = val() else: - res[key] = val + if isinstance(val, dict): + res[key] = self._expand(val) + else: + res[key] = val return res + def supports(self): + return self._expand(self._supports) + def supported(self, claim): return claim in self._supports @@ -219,3 +247,77 @@ def get_claim(self, key, default=None): return default else: return _val + + def get_endpoint_claims(self, endpoints): + _info = {} + for endp in endpoints: + if endp.endpoint_name: + _info[endp.endpoint_name] = endp.full_path + for arg, claim in [("client_authn_method", "auth_methods"), + ("auth_signing_alg_values", "auth_signing_alg_values")]: + _val = getattr(endp, arg, None) + if _val: + # trust_mark_status_endpoint_auth_methods_supported + md_param = f"{endp.endpoint_name}_{claim}" + _info[md_param] = _val + return _info + + def get_server_metadata(self, + entity_type: Optional[str] = "", + endpoints: Optional[list] = None, + metadata_schema: Optional[Message] = None, + extra_claims: Optional[List[str]] = None, + **kwargs): + + metadata = self.prefer + # the claims that can appear in the metadata + if metadata_schema: + attr = list(metadata_schema.c_param.keys()) + else: + attr = [] + + if extra_claims: + attr.extend(extra_claims) + + if attr: + metadata = {k: v for k, v in metadata.items() if k in attr and v != []} + + # collect endpoints + if endpoints: + metadata.update(self.get_endpoint_claims(endpoints)) + + if entity_type: + return {entity_type: metadata} + else: + return metadata + + def get_client_metadata(self, + entity_type: Optional[str] = "", + metadata_schema: Optional[Message] = None, + extra_claims: Optional[List[str]] = None, + supported: Optional[dict] = None, + **kwargs): + + if supported is None: + supported = self.supports() + + if not self.use: + self.use = preferred_to_registered(self.prefer, supported=supported) + + metadata = self.use + # the claims that can appear in the metadata + if metadata_schema: + attr = list(metadata_schema.c_param.keys()) + else: + attr = [] + + if extra_claims: + attr.extend(extra_claims) + + if attr: + metadata = {k: v for k, v in metadata.items() if k in attr} + + if entity_type: + return {entity_type: metadata} + else: + return metadata diff --git a/src/idpyoidc/client/claims/oauth2.py b/src/idpyoidc/client/claims/oauth2.py index 9d093d40..16e90475 100644 --- a/src/idpyoidc/client/claims/oauth2.py +++ b/src/idpyoidc/client/claims/oauth2.py @@ -1,7 +1,7 @@ from typing import Optional from idpyoidc.client import claims -from idpyoidc.client.claims.transform import create_registration_request +from idpyoidc.transform import create_registration_request class Claims(claims.Claims): diff --git a/src/idpyoidc/client/claims/oidc.py b/src/idpyoidc/client/claims/oidc.py index 0529f162..d2b1b0b0 100644 --- a/src/idpyoidc/client/claims/oidc.py +++ b/src/idpyoidc/client/claims/oidc.py @@ -2,9 +2,9 @@ import os from typing import Optional -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.client import claims as client_claims -from idpyoidc.client.claims.transform import create_registration_request +from idpyoidc.transform import create_registration_request from idpyoidc.message.oidc import APPLICATION_TYPE_WEB from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse @@ -76,9 +76,9 @@ class Claims(client_claims.Claims): "encrypt_id_token_supported": None, # "grant_types_supported": ["authorization_code", "refresh_token"], "logo_uri": None, - "id_token_signing_alg_values_supported": metadata.get_signing_algs(), - "id_token_encryption_alg_values_supported": metadata.get_encryption_algs(), - "id_token_encryption_enc_values_supported": metadata.get_encryption_encs(), + "id_token_signing_alg_values_supported": alg_info.get_signing_algs(), + "id_token_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "id_token_encryption_enc_values_supported": alg_info.get_encryption_encs(), "initiate_login_uri": None, "jwks": None, "jwks_uri": None, @@ -96,7 +96,7 @@ def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] def verify_rules(self, supports): if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported" + "request_uri_parameter_supported" ): raise ValueError( "You have to chose one of 'request_parameter_supported' and " @@ -104,7 +104,7 @@ def verify_rules(self, supports): ) if self.get_preference("request_parameter_supported") or self.get_preference( - "request_uri_parameter_supported" + "request_uri_parameter_supported" ): if not self.get_preference("request_object_signing_alg_values_supported"): self.set_preference( diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index 6ccb6f49..5161aa06 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -7,7 +7,7 @@ from idpyoidc.client.service import Service from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_signing_algs from idpyoidc.time_util import time_sans_frac LOGGER = logging.getLogger(__name__) diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index b8450500..d8a058ef 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -16,7 +16,7 @@ from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_signing_algs from idpyoidc.time_util import utc_time_sans_frac logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/client/oauth2/add_on/jar.py b/src/idpyoidc/client/oauth2/add_on/jar.py index a775532b..209c0627 100644 --- a/src/idpyoidc/client/oauth2/add_on/jar.py +++ b/src/idpyoidc/client/oauth2/add_on/jar.py @@ -1,8 +1,7 @@ import logging from typing import Optional -from idpyoidc import claims -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.client.oidc.utils import construct_request_uri from idpyoidc.client.oidc.utils import request_object_encryption from idpyoidc.message.oidc import make_openid_request @@ -175,14 +174,14 @@ def jar_post_construct(request_args, service, **kwargs): def add_support( - service, - request_type: Optional[str] = "request_parameter", - request_dir: Optional[str] = "", - request_object_signing_alg: Optional[str] = "RS256", - expires_in: Optional[int] = DEFAULT_EXPIRES_IN, - with_jti: Optional[bool] = False, - request_object_encryption_alg: Optional[str] = "", - request_object_encryption_enc: Optional[str] = "", + service, + request_type: Optional[str] = "request_parameter", + request_dir: Optional[str] = "", + request_object_signing_alg: Optional[str] = "RS256", + expires_in: Optional[int] = DEFAULT_EXPIRES_IN, + with_jti: Optional[bool] = False, + request_object_encryption_alg: Optional[str] = "", + request_object_encryption_enc: Optional[str] = "", ): """ JAR support can only be considered if this client can access an authorization service. @@ -208,8 +207,8 @@ def add_support( args["request_dir"] = request_dir if request_object_encryption_enc and request_object_encryption_alg: - if request_object_encryption_enc in metadata.get_encryption_encs(): - if request_object_encryption_alg in metadata.get_encryption_algs(): + if request_object_encryption_enc in alg_info.get_encryption_encs(): + if request_object_encryption_alg in alg_info.get_encryption_algs(): args["request_object_encryption_enc"] = request_object_encryption_enc args["request_object_encryption_alg"] = request_object_encryption_alg else: diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index af629fab..2024612c 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -9,7 +9,7 @@ from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import verified_claim_name -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_signing_algs from idpyoidc.time_util import time_sans_frac __author__ = "Roland Hedberg" diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 03cde130..73c56929 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -3,7 +3,7 @@ from typing import Optional from typing import Union -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.client.oauth2 import authorization from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oidc import IDT2REG @@ -32,9 +32,9 @@ class Authorization(authorization.Authorization): error_msg = oidc.ResponseMessage _supports = { - "request_object_signing_alg_values_supported": metadata.get_signing_algs(), - "request_object_encryption_alg_values_supported": metadata.get_encryption_algs(), - "request_object_encryption_enc_values_supported": metadata.get_encryption_encs(), + "request_object_signing_alg_values_supported": alg_info.get_signing_algs(), + "request_object_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "request_object_encryption_enc_values_supported": alg_info.get_encryption_encs(), "response_types_supported": ["code", "id_token", "code id_token"], "request_parameter_supported": None, "request_uri_parameter_supported": None, @@ -213,7 +213,7 @@ def store_request_on_file(self, req, **kwargs): return _webname def construct_request_parameter( - self, req, request_param, audience=None, expires_in=0, **kwargs + self, req, request_param, audience=None, expires_in=0, **kwargs ): """Construct a request parameter""" alg = self.get_request_object_signing_alg(**kwargs) @@ -319,7 +319,7 @@ def oidc_post_construct(self, req, **kwargs): return req def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -379,12 +379,12 @@ def _do_type(self, context, typ, response_types): return "" def construct_uris( - self, - base_url: str, - hex: bytes, - context: ServiceContext, - targets: Optional[List[str]] = None, - response_types: Optional[List[str]] = None, + self, + base_url: str, + hex: bytes, + context: ServiceContext, + targets: Optional[List[str]] = None, + response_types: Optional[List[str]] = None, ): _callback_uris = context.get_preference("callback_uris", {}) diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 05fce76b..b92410d4 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -8,9 +8,9 @@ from idpyoidc.exception import MissingSigningKey from idpyoidc.message import Message from idpyoidc.message import oidc -from idpyoidc.metadata import get_encryption_algs -from idpyoidc.metadata import get_encryption_encs -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_encryption_algs +from idpyoidc.alg_info import get_encryption_encs +from idpyoidc.alg_info import get_signing_algs logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index 56c91031..674a78d0 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -1,12 +1,12 @@ +from idpyoidc.alg_info import get_signing_algs from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo +from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.message import oauth2 from idpyoidc.message.oauth2 import ResponseMessage -from idpyoidc.metadata import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index 17c7e85b..e0bc430e 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -1,13 +1,13 @@ +from idpyoidc.alg_info import get_signing_algs from idpyoidc.client.client_auth import get_client_authn_methods from idpyoidc.client.oauth2 import access_token from idpyoidc.client.oidc import userinfo +from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_JSON from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.message import oauth2 -from idpyoidc.metadata import get_signing_algs class AccessTokenResponse(Message): diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index eea94c0b..8a79fc65 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -3,19 +3,25 @@ import traceback from typing import List from typing import Optional +from typing import Union from cryptojwt import KeyJar +from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes from cryptojwt.utils import importer +from idpyoidc.client.configure import RPHConfiguration from idpyoidc.client.defaults import DEFAULT_CLIENT_CONFIGS from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES -from idpyoidc.client.defaults import DEFAULT_RP_KEY_DEFS from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient +from idpyoidc.configure import Base from idpyoidc.util import add_path from idpyoidc.util import rndstr +from .defaults import DEFAULT_KEY_DEFS from .oauth2 import Client +from ..key_import import import_jwks +from ..key_import import store_under_other_id from ..message import Message logger = logging.getLogger(__name__) @@ -34,50 +40,51 @@ def __init__( state_db=None, httpc=None, httpc_params=None, - config=None, + config: Optional[Union[dict, Base]] = None, **kwargs, ): - self.base_url = base_url - - if keyjar is None: - keyjar_defs = {} - if config: - keyjar_defs = getattr(config, "key_conf", None) - - if not keyjar_defs: - keyjar_defs = kwargs.get("key_conf", DEFAULT_RP_KEY_DEFS) - - _jwks_path = kwargs.get("jwks_path", keyjar_defs.get("uri_path", keyjar_defs.get("public_path", ""))) - if "uri_path" in keyjar_defs: - del keyjar_defs["uri_path"] - self.keyjar = init_key_jar(**keyjar_defs, issuer_id="") - self.keyjar.import_jwks_as_json(self.keyjar.export_jwks_as_json(True, ""), base_url) - else: + if config is None: + config = RPHConfiguration({}) + elif isinstance(config, dict): + config = RPHConfiguration(config) + + self.base_url = base_url or config.get("base_url", config.get("entity_id", "")) + self.entity_id = config.get("entity_id", config.conf.get("entity_id", self.base_url)) + self.entity_type = config.get("entity_type", config.conf.get("entity_type", "")) + self.client_type = config.get("client_type", config.conf.get("client_type", "")) + self.client_configs = client_configs or {} + + if keyjar: self.keyjar = keyjar _jwks_path = kwargs.get("jwks_path", "") - - if _jwks_path: - self.jwks_uri = add_path(base_url, _jwks_path) - else: - self.jwks_uri = "" - if len(self.keyjar): - self.jwks = self.keyjar.export_jwks() + if _jwks_path: + self.jwks_uri = add_path(base_url, _jwks_path) else: - self.jwks = {} + self.jwks_uri = "" + if len(self.keyjar): + self.jwks = self.keyjar.export_jwks() + else: + self.jwks = {} if config: if not hash_seed: self.hash_seed = config.hash_seed - if not keyjar: - self.keyjar = init_key_jar(**config.key_conf, issuer_id="") - if not client_configs: - self.client_configs = config.clients - - if "client_class" in config: - if isinstance(config["client_class"], str): - self.client_cls = importer(config["client_class"]) + + if not keyjar and config.key_conf: + _conf = {k: v for k, v in config.key_conf.items() if k != "uri_path"} + self.keyjar = init_key_jar(**_conf, issuer_id="") + _jwks_path = kwargs.get("jwks_path", + config.key_conf.get("uri_path", + config.key_conf.get("public_path", ""))) + if _jwks_path: + self.jwks_uri = add_path(self.base_url, _jwks_path) + + _c_class = config.get("client_class", config.conf.get("client_class")) + if _c_class: + if isinstance(_c_class, str): + self.client_cls = importer(_c_class) else: # assume it's a class - self.client_cls = config["client_class"] + self.client_cls = _c_class else: self.client_cls = StandAloneClient else: @@ -86,23 +93,22 @@ def __init__( else: self.hash_seed = as_bytes(rndstr(32)) - if client_configs is None: - self.client_configs = DEFAULT_CLIENT_CONFIGS - for param in ["client_type", "preference", "add_ons"]: - val = kwargs.get(param, None) - if val: - self.client_configs[""][param] = val - else: - self.client_configs = client_configs - _cc = kwargs.get("client_class", None) if _cc: if isinstance(_cc, str): _cc = importer(_cc) - self.client_cls =_cc + self.client_cls = _cc else: self.client_cls = StandAloneClient + if client_configs is None: + self.client_configs = DEFAULT_CLIENT_CONFIGS + for param in ["client_type", "preference", "add_ons"]: + val = kwargs.get(param, None) + if val: + self.client_configs[""][param] = val + else: + self.client_configs = client_configs if state_db: self.state_db = state_db @@ -111,6 +117,9 @@ def __init__( self.extra = kwargs + if services is None: + services = config.get("services", config.conf.get("services", None)) + if services is None: self.services = DEFAULT_OIDC_SERVICES else: @@ -122,15 +131,20 @@ def __init__( self.httpc = httpc if not httpc_params: - self.httpc_params = {"verify": verify_ssl} + self.httpc_params = config.get("httpc_params", {"verify": verify_ssl}) else: self.httpc_params = httpc_params - if not self.keyjar.httpc_params: - self.keyjar.httpc_params = self.httpc_params - self.upstream_get = kwargs.get("upstream_get", None) + _keyjar = getattr(self, "keyjar", None) + if _keyjar is not None: + if not _keyjar.httpc_params: + _keyjar.httpc_params = getattr(self, "httpc_params", {}) + else: + self.keyjar = build_keyjar(DEFAULT_KEY_DEFS) + self.keyjar.httpc_params = getattr(self, "httpc_params", {}) + def state2issuer(self, state): """ Given the state value find the Issuer ID of the OP/AS that state value @@ -159,7 +173,13 @@ def pick_config(self, issuer): :param issuer: Issuer ID :return: A client configuration """ - return self.client_configs[issuer] + _cnf = self.client_configs[issuer].copy() + for param in ["entity_id", "client_id", "base_url", "services", "jwks_uri", "entity_type", + "client_type"]: + if param not in _cnf and getattr(self, param, None): + _cnf[param] = getattr(self, param) + + return _cnf def get_session_information(self, key, client=None): """ @@ -192,16 +212,7 @@ def init_client(self, issuer): _cnf = self.pick_config("") _cnf["issuer"] = issuer - try: - _services = _cnf["services"] - except KeyError: - _services = self.services - - if "base_url" not in _cnf: - _cnf["base_url"] = self.base_url - - if self.jwks_uri: - _cnf["jwks_uri"] = self.jwks_uri + _services = _cnf["services"] logger.debug(f"config: {_cnf}") try: @@ -221,20 +232,32 @@ def init_client(self, issuer): _context = client.get_context() if _context.iss_hash: self.hash2issuer[_context.iss_hash] = issuer + # If non persistent _keyjar = client.keyjar - if not _keyjar: + if _keyjar is None: _keyjar = KeyJar() _keyjar.httpc_params.update(self.httpc_params) - for iss in self.keyjar.owners(): - _keyjar.import_jwks(self.keyjar.export_jwks(issuer_id=iss, private=True), iss) + if self.upstream_get: + _srv_keyjar = self.upstream_get("attribute", "keyjar") + else: + _srv_keyjar = getattr(self, "keyjar", None) + + if _srv_keyjar: + for iss in _srv_keyjar.owners(): + _keyjar = import_jwks(_keyjar, self.keyjar.export_jwks(issuer_id=iss, private=True), iss) + + if self.entity_id not in _keyjar: + _keyjar = store_under_other_id(_keyjar,"", self.entity_id, True) client.keyjar = _keyjar # If persistent nothing has to be copied - _context.base_url = self.base_url - _context.jwks_uri = self.jwks_uri + for item in ["jwks_uri", "base_url"]: + _val = getattr(self, item, None) + if _val: + setattr(_context, item, _val) return client def do_provider_info( @@ -639,7 +662,8 @@ def logout( return client.logout(state, post_logout_redirect_uri=post_logout_redirect_uri) def close( - self, state: str, issuer: Optional[str] = "", post_logout_redirect_uri: Optional[str] = "" + self, state: str, issuer: Optional[str] = "", + post_logout_redirect_uri: Optional[str] = "" ) -> dict: if issuer: diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index d101fae2..b4d391c2 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -8,8 +8,8 @@ from typing import Optional from typing import Union -from cryptojwt.jwk.rsa import RSAKey from cryptojwt.jwk.rsa import import_private_rsa_key_from_file +from cryptojwt.jwk.rsa import RSAKey from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.utils import as_bytes @@ -21,13 +21,12 @@ from idpyoidc.client.claims.oauth2resource import Claims as OAUTH2RESOURCE_Specs from idpyoidc.client.claims.oidc import Claims as OIDC_Specs from idpyoidc.client.configure import Configuration +from idpyoidc.transform import preferred_to_registered +from idpyoidc.transform import supported_to_preferred from idpyoidc.util import rndstr - -from ..impexp import ImpExp -from .claims.transform import preferred_to_registered -from .claims.transform import supported_to_preferred from .configure import get_configuration from .current import Current +from ..impexp import ImpExp logger = logging.getLogger(__name__) @@ -116,14 +115,14 @@ class ServiceContext(ImpExp): init_args = ["upstream_get"] def __init__( - self, - upstream_get: Optional[Callable] = None, - base_url: Optional[str] = "", - keyjar: Optional[KeyJar] = None, - config: Optional[Union[dict, Configuration]] = None, - cstate: Optional[Current] = None, - client_type: Optional[str] = "oauth2", - **kwargs, + self, + upstream_get: Optional[Callable] = None, + base_url: Optional[str] = "", + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + cstate: Optional[Current] = None, + client_type: Optional[str] = "oauth2", + **kwargs, ): ImpExp.__init__(self) config = get_configuration(config) @@ -212,7 +211,7 @@ def filename_from_webname(self, webname): if not webname.startswith(self.base_url): raise ValueError("Webname doesn't match base_url") - _name = webname[len(self.base_url) :] + _name = webname[len(self.base_url):] if _name.startswith("/"): return _name[1:] @@ -294,15 +293,15 @@ def get(self, key, default=None): def set(self, key, value): setattr(self, key, value) - def get_client_id(self): - res = self.claims.get_usage("client_id") - if not res: - res = self.entity_id - if not res and self.upstream_get: - res = self.upstream_get("unit").entity_id - + def get_entity_id(self): + res = self.entity_id + if not res and self.upstream_get: + res = self.upstream_get("unit").entity_id return res + def get_client_id(self): + return self.claims.get_usage("client_id") + def collect_usage(self): return self.claims.use diff --git a/src/idpyoidc/key_import.py b/src/idpyoidc/key_import.py new file mode 100644 index 00000000..9b33f50a --- /dev/null +++ b/src/idpyoidc/key_import.py @@ -0,0 +1,76 @@ +import json +from typing import List +from typing import Optional + +from cryptojwt import JWK +from cryptojwt import KeyBundle +from cryptojwt import KeyJar +from cryptojwt.jwk.hmac import SYMKey +from cryptojwt.jwk.jwk import key_from_jwk_dict + + +def issuer_keys(keyjar: KeyJar, entity_id: str, format: Optional[str] = "jwk"): + # sort of copying the functionality in KeyJar.get_issuer_keys() + key_issuer = keyjar.return_issuer(entity_id) + if format == "jwk": + return [k.serialize() for k in key_issuer.all_keys()] + else: + return [k for k in key_issuer.all_keys()] + + +def import_jwks(keyjar: KeyJar, jwks: dict, entity_id: Optional[str] = "") -> KeyJar: + keys = [] + jar = issuer_keys(keyjar, entity_id) + for jwk in jwks["keys"]: + if jwk not in jar: + jar.append(jwk) + key = key_from_jwk_dict(jwk) + keys.append(key) + if keys: + keyjar.add_keys(entity_id, keys) + return keyjar + + +def import_jwks_as_json(keyjar: KeyJar, jwks: str, entity_id: Optional[str] = "") -> KeyJar: + return import_jwks(keyjar, json.loads(jwks), entity_id) + + +def import_jwks_from_file(keyjar: KeyJar, filename: str, entity_id) -> KeyJar: + with open(filename) as jwks_file: + keyjar = import_jwks_as_json(keyjar, jwks_file.read(), entity_id) + return keyjar + + +def add_kb(keyjar: KeyJar, key_bundle: KeyBundle, entity_id: str) -> KeyJar: + return import_jwks(keyjar, json.loads(key_bundle.jwks()), entity_id) + + +def add_symmetric(keyjar: KeyJar, key: str, entity_id: Optional[str] = "") -> KeyJar: + jar = issuer_keys(keyjar, entity_id) + _sym_key = SYMKey(key=key) + + jwk = _sym_key.serialize() + if jwk not in jar: + keyjar.add_symmetric(entity_id, key) + return keyjar + + +def store_under_other_id(keyjar: KeyJar, fro: Optional[str] = "", to: Optional[str] = "", + private: Optional[bool] = False) -> KeyJar: + if fro == to: + return keyjar + else: + return import_jwks(keyjar, keyjar.export_jwks(private, fro), to) + + +def add_keys(keyjar:KeyJar, keys: List[JWK], entity_id) -> KeyJar: + _keys = [] + jar = issuer_keys(keyjar, entity_id) + for key in keys: + jwk = key.serialize() + if jwk not in jar: + jar.append(jwk) + _keys.append(key) + if _keys: + keyjar.add_keys(entity_id, _keys) + return keyjar diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index d266224d..fc5d114c 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -915,8 +915,8 @@ class ProviderConfigurationResponse(ResponseMessage): "token_endpoint_auth_methods_supported": ["client_secret_basic"], "claims_parameter_supported": False, "request_parameter_supported": False, - "request_uri_parameter_supported": True, - "require_request_uri_registration": True, + "request_uri_parameter_supported": None, + "require_request_uri_registration": None, "grant_types_supported": ["authorization_code"], } diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py index 7561d483..e69de29b 100644 --- a/src/idpyoidc/metadata.py +++ b/src/idpyoidc/metadata.py @@ -1,274 +0,0 @@ -from functools import cmp_to_key -import logging -from typing import Callable -from typing import Optional - -from cryptojwt import KeyJar -from cryptojwt.jwe import SUPPORTED -from cryptojwt.jws.jws import SIGNER_ALGS -from cryptojwt.key_jar import init_key_jar -from cryptojwt.utils import importer - -from idpyoidc.client.util import get_uri -from idpyoidc.impexp import ImpExp -from idpyoidc.util import add_path -from idpyoidc.util import qualified_name - -logger = logging.getLogger(__name__) - - -def metadata_dump(info, exclude_attributes): - return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} - - -def metadata_load(item: dict, **kwargs): - _class_name = list(item.keys())[0] # there is only one - _cls = importer(_class_name) - _cls = _cls().load(item[_class_name]) - return _cls - - -class Metadata(ImpExp): - parameter = {"prefer": None, "use": None, "callback_path": None, "_local": None} - - _supports = {} - - def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): - - ImpExp.__init__(self) - if isinstance(prefer, dict): - self.prefer = {k: v for k, v in prefer.items() if k in self._supports} - else: - self.prefer = {} - - self.callback_path = callback_path or {} - self.use = {} - self._local = {} - - def get_use(self): - return self.use - - def set_usage(self, key, value): - self.use[key] = value - - def get_usage(self, key, default=None): - return self.use.get(key, default) - - def get_preference(self, key, default=None): - return self.prefer.get(key, default) - - def set_preference(self, key, value): - self.prefer[key] = value - - def remove_preference(self, key): - if key in self.prefer: - del self.prefer[key] - - def _callback_uris(self, base_url, hex): - _uri = [] - for type in self.get_usage("response_types", self._supports["response_types"]): - if "code" in type: - _uri.append("code") - elif type in ["id_token", "id_token token"]: - _uri.append("implicit") - - if "form_post" in self.supports: - _uri.append("form_post") - - callback_uri = {} - for key in _uri: - callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) - return callback_uri - - def construct_redirect_uris(self, base_url: str, hex: str, callbacks: Optional[dict] = None): - if not callbacks: - callbacks = self._callback_uris(base_url, hex) - - if callbacks: - self.set_preference("callbacks", callbacks) - self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) - - self.callback = callbacks - - def verify_rules(self, supports): - return True - - def locals(self, info): - pass - - def _keyjar(self, keyjar=None, conf=None, entity_id=""): - _uri_path = "" - if keyjar is None: - if "keys" in conf: - keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - _uri_path = conf["keys"].get("uri_path") - elif "key_conf" in conf and conf["key_conf"]: - keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} - _keyjar = init_key_jar(**keys_args) - _uri_path = conf["key_conf"].get("uri_path") - else: - _keyjar = KeyJar() - if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") - - if "" in _keyjar and entity_id: - # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), entity_id) - - _httpc_params = conf.get("httpc_params") - if _httpc_params: - _keyjar.httpc_params = _httpc_params - - return _keyjar, _uri_path - else: - if "keys" in conf: - _uri_path = conf["keys"].get("uri_path") - elif "key_conf" in conf and conf["key_conf"]: - _uri_path = conf["key_conf"].get("uri_path") - return keyjar, _uri_path - - def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): - raise NotImplementedError() - - def get_id(self, configuration: dict): - raise NotImplementedError() - - def add_extra_keys(self, keyjar, id): - return None - - def get_jwks(self, keyjar): - return None - - def handle_keys(self, - configuration: dict, - keyjar: Optional[KeyJar] = None, - base_url: Optional[str] = "", - entity_id: Optional[str] = ""): - _jwks = _jwks_uri = None - _id = self.get_id(configuration) - keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) - - self.add_extra_keys(keyjar, _id) - - # now that keys are in the Key Jar, now for how to publish it - if "jwks_uri" in configuration: # simple - _jwks_uri = configuration.get("jwks_uri") - elif uri_path: - if not base_url: - base_url = self.get_base_url(configuration, entity_id=entity_id) - _jwks_uri = add_path(base_url, uri_path) - else: # jwks or nothing - _jwks = self.get_jwks(keyjar) - - return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri} - - def load_conf( - self, configuration, supports, keyjar: Optional[KeyJar] = None, - base_url: Optional[str] = "" - ): - for attr, val in configuration.items(): - if attr == "preference": - for k, v in val.items(): - if k in supports: - self.set_preference(k, v) - elif attr in supports: - self.set_preference(attr, val) - - self.locals(configuration) - - for key, val in self.handle_keys(configuration, keyjar=keyjar, base_url=base_url).items(): - if key == "keyjar": - keyjar = val - elif val: - self.set_preference(key, val) - - self.verify_rules(supports) - return keyjar - - def get(self, key, default=None): - if key in self._local: - return self._local[key] - else: - return default - - def set(self, key, val): - self._local[key] = val - - def construct_uris(self, *args): - pass - - def supports(self): - res = {} - for key, val in self._supports.items(): - if isinstance(val, Callable): - res[key] = val() - else: - res[key] = val - return res - - def supported(self, claim): - return claim in self._supports - - def prefers(self): - return self.prefer - - -SIGNING_ALGORITHM_SORT_ORDER = ["RS", "ES", "PS", "HS", "Ed"] - - -def cmp(a, b): - return (a > b) - (a < b) - - -def alg_cmp(a, b): - if a == "none": - return 1 - elif b == "none": - return -1 - - _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) - _pos2 = SIGNING_ALGORITHM_SORT_ORDER.index(b[0:2]) - if _pos1 == _pos2: - return (a > b) - (a < b) - elif _pos1 > _pos2: - return 1 - else: - return -1 - - -def get_signing_algs(): - # Assumes Cryptojwt - _algs = [name for name in list(SIGNER_ALGS.keys()) if name != "none"] - return sorted(_algs, key=cmp_to_key(alg_cmp)) - - -def get_encryption_algs(): - return SUPPORTED["alg"] - - -def get_encryption_encs(): - return SUPPORTED["enc"] - - -def array_or_singleton(claim_spec, values): - if isinstance(claim_spec[0], list): - if isinstance(values, list): - return values - else: - return [values] - else: - if isinstance(values, list): - return values[0] - else: # singleton - return values - - -def is_subset(a, b): - if isinstance(a, list): - if isinstance(b, list): - return set(b).issubset(set(a)) - elif isinstance(b, list): - return a in b - else: - return a == b diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index 0db622a7..2b64d6e9 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -7,14 +7,17 @@ from idpyoidc.configure import Configuration from idpyoidc.impexp import ImpExp +from idpyoidc.key_import import import_jwks +from idpyoidc.key_import import import_jwks_as_json +from idpyoidc.key_import import store_under_other_id from idpyoidc.util import instantiate def create_keyjar( - keyjar: Optional[KeyJar] = None, - conf: Optional[Union[dict, Configuration]] = None, - key_conf: Optional[dict] = None, - id: Optional[str] = "", + keyjar: Optional[KeyJar] = None, + conf: Optional[Union[dict, Configuration]] = None, + key_conf: Optional[dict] = None, + id: Optional[str] = "", ): if keyjar is None: if key_conf: @@ -30,13 +33,13 @@ def create_keyjar( else: _keyjar = KeyJar() if "jwks" in conf: - _keyjar.import_jwks(conf["jwks"], "") + _keyjar = import_jwks(_keyjar, conf["jwks"], "") else: _keyjar = None if _keyjar and "" in _keyjar and id: # make sure I have the keys under my own name too (if I know it) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), id) + _keyjar = store_under_other_id(_keyjar, "", id, True) return _keyjar else: @@ -60,7 +63,7 @@ def make_keyjar( keyjar = KeyJar() _jwks = config.get("jwks") if _jwks: - keyjar.import_jwks_as_json(_jwks, client_id) + keyjar = import_jwks_as_json(keyjar, _jwks, client_id) if keyjar or key_conf: # Should be either one @@ -78,15 +81,12 @@ def make_keyjar( keyjar = KeyJar() keyjar.add_symmetric(client_id, _key) keyjar.add_symmetric("", _key) - # else: - # keyjar = build_keyjar(DEFAULT_KEY_DEFS) - # if issuer_id: - # keyjar.import_jwks(keyjar.export_jwks(private=True), issuer_id) return keyjar class Node: + def __init__(self, upstream_get: Callable = None): self.upstream_get = upstream_get @@ -123,19 +123,20 @@ class Unit(ImpExp): init_args = ["upstream_get"] def __init__( - self, - upstream_get: Callable = None, - keyjar: Optional[Union[KeyJar, bool]] = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - config: Optional[Union[Configuration, dict]] = None, - key_conf: Optional[dict] = None, - issuer_id: Optional[str] = "", - client_id: Optional[str] = "", + self, + upstream_get: Callable = None, + keyjar: Optional[Union[KeyJar, bool]] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + key_conf: Optional[dict] = None, + issuer_id: Optional[str] = "", + client_id: Optional[str] = "", ): ImpExp.__init__(self) self.upstream_get = upstream_get self.httpc = httpc + self.client_id = client_id if config is None: config = {} @@ -192,16 +193,16 @@ class ClientUnit(Unit): name = "" def __init__( - self, - upstream_get: Callable = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - keyjar: Optional[KeyJar] = None, - context: Optional[ImpExp] = None, - config: Optional[Union[Configuration, dict]] = None, - # jwks_uri: Optional[str] = "", - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None, + self, + upstream_get: Callable = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + keyjar: Optional[KeyJar] = None, + context: Optional[ImpExp] = None, + config: Optional[Union[Configuration, dict]] = None, + # jwks_uri: Optional[str] = "", + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, ): if config is None: config = {} @@ -232,17 +233,18 @@ def get_context_attribute(self, attr, *args): # Neither client nor Server class Collection(Unit): + def __init__( - self, - upstream_get: Callable = None, - keyjar: Optional[KeyJar] = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - config: Optional[Union[Configuration, dict]] = None, - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None, - functions: Optional[dict] = None, - claims: Optional[dict] = None, + self, + upstream_get: Callable = None, + keyjar: Optional[KeyJar] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, + functions: Optional[dict] = None, + claims: Optional[dict] = None, ): if config is None: config = {} diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index 2c258ba2..6d5efd6a 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -1,6 +1,6 @@ from typing import Optional -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.message.oidc import ProviderConfigurationResponse from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc import RegistrationResponse @@ -48,9 +48,9 @@ class Claims(server_claims.Claims): "display_values_supported": None, "encrypt_id_token_supported": None, # "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], - "id_token_signing_alg_values_supported": metadata.get_signing_algs(), - "id_token_encryption_alg_values_supported": metadata.get_encryption_algs(), - "id_token_encryption_enc_values_supported": metadata.get_encryption_encs(), + "id_token_signing_alg_values_supported": alg_info.get_signing_algs(), + "id_token_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "id_token_encryption_enc_values_supported": alg_info.get_encryption_encs(), "initiate_login_uri": None, "jwks": None, "jwks_uri": None, diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 3f304e7c..a6ecbe05 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -82,8 +82,8 @@ "kwargs": { "client_authn_method": None, "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, + "request_parameter_supported": None, + "request_uri_parameter_supported": None, "response_types_supported": ["code"], "response_modes_supported": ["query", "fragment", "form_post"], }, @@ -151,8 +151,8 @@ "kwargs": { "client_authn_method": None, "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, + "request_parameter_supported": None, + "request_uri_parameter_supported": None, "response_types_supported": [ "code", # "token", @@ -479,8 +479,8 @@ def __init__( "kwargs": { "client_authn_method": None, "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, + "request_parameter_supported": None, + "request_uri_parameter_supported": None, "response_types_supported": [ "code", # "token", diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 5148cfe0..2e7ae1e5 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -14,7 +14,7 @@ from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.metadata import get_signing_algs +from idpyoidc.alg_info import get_signing_algs from idpyoidc.server.client_authn import BearerHeader logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 0766af57..e2cd4fa7 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -4,9 +4,9 @@ from typing import Optional from typing import TypeVar from typing import Union +from urllib.parse import parse_qs from urllib.parse import ParseResult from urllib.parse import SplitResult -from urllib.parse import parse_qs from urllib.parse import unquote from urllib.parse import urlencode from urllib.parse import urlparse @@ -18,7 +18,7 @@ from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.exception import ImproperlyConfigured from idpyoidc.exception import ParameterError from idpyoidc.exception import URIError @@ -48,7 +48,6 @@ from idpyoidc.util import importer from idpyoidc.util import rndstr - ParsedURI = TypeVar('ParsedURI', ParseResult, SplitResult) logger = logging.getLogger(__name__) @@ -392,13 +391,13 @@ class Authorization(Endpoint): _supports = { "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, + "request_parameter_supported": None, + "request_uri_parameter_supported": None, "response_types_supported": ["code"], "response_modes_supported": ["query", "fragment", "form_post"], - "request_object_signing_alg_values_supported": metadata.get_signing_algs(), - "request_object_encryption_alg_values_supported": metadata.get_encryption_algs(), - "request_object_encryption_enc_values_supported": metadata.get_encryption_encs(), + "request_object_signing_alg_values_supported": alg_info.get_signing_algs(), + "request_object_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "request_object_encryption_enc_values_supported": alg_info.get_encryption_encs(), # "grant_types_supported": ["authorization_code", "implicit"], "code_challenge_methods_supported": ["S256"], "scopes_supported": [], @@ -911,7 +910,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if isinstance(request["response_type"], list): rtype = set(request["response_type"][:]) - else: # assume it's a string + else: # assume it's a string rtype = set() rtype.add(request["response_type"]) diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index e6daad35..29e93886 100644 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -2,7 +2,7 @@ from typing import Callable from urllib.parse import urlsplit -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.message import oidc from idpyoidc.message.oidc import Claims from idpyoidc.message.oidc import verified_claim_name @@ -82,12 +82,12 @@ class Authorization(authorization.Authorization): **{ "claims_parameter_supported": True, "encrypt_request_object_supported": False, - "request_object_signing_alg_values_supported": metadata.get_signing_algs(), - "request_object_encryption_alg_values_supported": metadata.get_encryption_algs(), - "request_object_encryption_enc_values_supported": metadata.get_encryption_encs(), - "request_parameter_supported": True, - "request_uri_parameter_supported": True, - "require_request_uri_registration": False, + "request_object_signing_alg_values_supported": alg_info.get_signing_algs(), + "request_object_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "request_object_encryption_enc_values_supported": alg_info.get_encryption_encs(), + "request_parameter_supported": None, + "request_uri_parameter_supported": None, + "require_request_uri_registration": None, "response_types_supported": ["code", "id_token", "code id_token"], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise", "ephemeral"], diff --git a/src/idpyoidc/server/oidc/token.py b/src/idpyoidc/server/oidc/token.py index 3436df32..804fa591 100755 --- a/src/idpyoidc/server/oidc/token.py +++ b/src/idpyoidc/server/oidc/token.py @@ -1,6 +1,6 @@ import logging -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.message import Message from idpyoidc.message import oidc from idpyoidc.message.oidc import TokenErrorResponse @@ -40,7 +40,7 @@ class Token(token.Token): "client_secret_jwt", "private_key_jwt", ], - "token_endpoint_auth_signing_alg_values_supported": metadata.get_signing_algs(), + "token_endpoint_auth_signing_alg_values_supported": alg_info.get_signing_algs(), "grant_types_supported": list(helper_by_grant_type.keys()), } diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 281b669d..27557047 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -1,15 +1,13 @@ import json import logging -from datetime import datetime from typing import Callable from typing import Optional from typing import Union from cryptojwt.exception import MissingValue from cryptojwt.jwt import JWT -from cryptojwt.jwt import utc_time_sans_frac -from idpyoidc import metadata +from idpyoidc import alg_info from idpyoidc.exception import ImproperlyConfigured from idpyoidc.message import Message from idpyoidc.message import oidc @@ -35,13 +33,13 @@ class UserInfo(Endpoint): _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], "encrypt_userinfo_supported": True, - "userinfo_signing_alg_values_supported": metadata.get_signing_algs(), - "userinfo_encryption_alg_values_supported": metadata.get_encryption_algs(), - "userinfo_encryption_enc_values_supported": metadata.get_encryption_encs(), + "userinfo_signing_alg_values_supported": alg_info.get_signing_algs(), + "userinfo_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "userinfo_encryption_enc_values_supported": alg_info.get_encryption_encs(), } def __init__( - self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs + self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs ): Endpoint.__init__( self, @@ -58,11 +56,11 @@ def get_client_id_from_token(self, context, token, request=None): return _info["client_id"] def do_response( - self, - response_args: Optional[Union[Message, dict]] = None, - request: Optional[Union[Message, dict]] = None, - client_id: Optional[str] = "", - **kwargs, + self, + response_args: Optional[Union[Message, dict]] = None, + request: Optional[Union[Message, dict]] = None, + client_id: Optional[str] = "", + **kwargs, ) -> dict: if "error" in kwargs and kwargs["error"]: return Endpoint.do_response(self, response_args, request, **kwargs) diff --git a/src/idpyoidc/client/claims/transform.py b/src/idpyoidc/transform.py similarity index 94% rename from src/idpyoidc/client/claims/transform.py rename to src/idpyoidc/transform.py index 1ca40c6c..3834006c 100644 --- a/src/idpyoidc/client/claims/transform.py +++ b/src/idpyoidc/transform.py @@ -51,10 +51,10 @@ def supported_to_preferred( - supported: dict, - preference: dict, - base_url: str, - info: Optional[dict] = None, + supported: dict, + preference: dict, + base_url: str, + info: Optional[dict] = None, ): if info: # The provider info for key, val in supported.items(): @@ -83,7 +83,7 @@ def supported_to_preferred( preference[key] = [x for x in val if x in _info_val] else: pass - else: + elif val: preference[key] = val # special case -> must have a request_uris value @@ -148,7 +148,7 @@ def _intersection(a, b): def preferred_to_registered( - prefers: dict, supported: dict, registration_response: Optional[dict] = None + prefers: dict, supported: dict, registration_response: Optional[dict] = None ): """ The claims with values that are returned from the OP is what goes unless (!!) @@ -200,7 +200,7 @@ def preferred_to_registered( # be a singleton or an array. So just add it as is. registered[_reg_key] = val - logger.debug(f"Entity registered: {registered}") + logger.debug(f"preferred2registered: {registered}") return registered @@ -219,4 +219,10 @@ def create_registration_request(prefers: dict, supported: dict) -> dict: continue _request[key] = array_or_singleton(spec, value) + + for key, val in prefers.items(): + if key not in RegistrationRequest.c_param.keys(): + if key not in REGISTER2PREFERRED.values(): + _request[key] = val + return _request diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index 5e5df3f8..339a33a6 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -191,6 +191,6 @@ def create_client(self): def test_keyjar(self): _keyjar = self.client.get_attribute("keyjar") - assert len(_keyjar) == 2 # one issuer - assert len(_keyjar[""]) == 3 - assert len(_keyjar.get("sig")) == 3 + assert len(_keyjar) == 2 # "", and client_id + assert len(_keyjar[""]) == 2 + assert len(_keyjar.get("sig")) == 2 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index 5eba310b..f0d83006 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -1,13 +1,13 @@ import os +import pytest +import responses from cryptojwt.exception import UnsupportedAlgorithm from cryptojwt.jws import jws from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar -import pytest -import responses from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.entity import Entity @@ -30,6 +30,7 @@ class Response(object): + def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -70,6 +71,7 @@ def make_keyjar(): class TestAuthorization(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -312,6 +314,7 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): class TestAuthorizationCallback(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -397,6 +400,7 @@ def test_construct_form_post(self): class TestAccessTokenRequest(object): + @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -475,6 +479,7 @@ def test_id_token_nonce_match(self): class TestProviderInfo(object): + @pytest.fixture(autouse=True) def create_service(self): self._iss = ISS @@ -548,8 +553,6 @@ def test_post_parse(self): "private_key_jwt", ], "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, # "require_request_uri_registration": True, "grant_types_supported": [ "authorization_code", @@ -877,6 +880,7 @@ def create_jws(val): class TestRegistration(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -901,12 +905,16 @@ def test_construct(self): assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == { "application_type", + 'callback_uris', "default_max_age", + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", "redirect_uris", "request_object_signing_alg", + 'requests_dir', "response_modes", "response_types", "subject_type", @@ -924,13 +932,17 @@ def test_config_with_post_logout(self): assert isinstance(_req, RegistrationRequest) assert set(_req.keys()) == { "application_type", + 'callback_uris', "default_max_age", + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", "post_logout_redirect_uri", "redirect_uris", "request_object_signing_alg", + 'requests_dir', "response_modes", "response_types", "subject_type", @@ -966,22 +978,27 @@ def test_config_with_required_request_uri(): reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == { - "application_type", - "response_modes", - "response_types", - "jwks", - "redirect_uris", - "grant_types", - "id_token_signed_response_alg", - "request_uris", - "default_max_age", - "request_object_signing_alg", - "subject_type", - "token_endpoint_auth_method", - "token_endpoint_auth_signing_alg", - "userinfo_signed_response_alg", - } + assert set(_req.keys()) == {'application_type', + 'callback_uris', + 'client_id', + 'client_secret', + 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'request_parameter', + 'request_uris', + 'requests_dir', + 'response_modes', + 'response_types', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} def test_config_logout_uri(): @@ -1021,25 +1038,31 @@ def test_config_logout_uri(): reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == { - "application_type", - "default_max_age", - "grant_types", - "id_token_signed_response_alg", - "jwks", - "redirect_uris", - "request_object_signing_alg", - "request_uris", - "response_modes", - "response_types", - "subject_type", - "token_endpoint_auth_method", - "token_endpoint_auth_signing_alg", - "userinfo_signed_response_alg", - } + assert set(_req.keys()) == {'application_type', + 'callback_uris', + 'client_id', + 'client_secret', + 'default_max_age', + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', + 'grant_types', + 'id_token_signed_response_alg', + 'jwks', + 'redirect_uris', + 'request_object_signing_alg', + 'request_parameter', + 'request_uris', + 'requests_dir', + 'response_modes', + 'response_types', + 'subject_type', + 'token_endpoint_auth_method', + 'token_endpoint_auth_signing_alg', + 'userinfo_signed_response_alg'} class TestUserInfo(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1186,6 +1209,7 @@ def test_unpack_encrypted_response(self): class TestCheckSession(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1213,6 +1237,7 @@ def test_construct(self): class TestCheckID(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1240,6 +1265,7 @@ def test_construct(self): class TestEndSession(object): + @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index fbb22399..99264876 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -202,8 +202,6 @@ def test_conversation(): "private_key_jwt", ], "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, "require_request_uri_registration": True, "grant_types_supported": [ "authorization_code", @@ -403,8 +401,11 @@ def test_conversation(): "application_type", "backchannel_logout_session_required", "backchannel_logout_uri", + 'callback_uris', "contacts", "default_max_age", + 'encrypt_request_object_supported', + 'encrypt_userinfo_supported', "grant_types", "id_token_signed_response_alg", "jwks", diff --git a/tests/test_client_30_rp_handler_oidc.py b/tests/test_client_30_rp_handler_oidc.py index 3a3d75f9..b8d6430c 100644 --- a/tests/test_client_30_rp_handler_oidc.py +++ b/tests/test_client_30_rp_handler_oidc.py @@ -4,12 +4,13 @@ from urllib.parse import urlparse from urllib.parse import urlsplit -from cryptojwt.key_jar import init_key_jar import pytest import responses +from cryptojwt.key_jar import init_key_jar from idpyoidc.client.entity import Entity from idpyoidc.client.rp_handler import RPHandler +from idpyoidc.key_import import import_jwks from idpyoidc.message.oidc import AccessTokenResponse from idpyoidc.message.oidc import APPLICATION_TYPE_WEB from idpyoidc.message.oidc import AuthorizationResponse @@ -151,8 +152,6 @@ "authorization_endpoint": "https://github.com/login/oauth/authorize", "token_endpoint": "https://github.com/login/oauth/access_token", "userinfo_endpoint": "https://api.github.com/user", - "request_parameter_supported": True, - "request_uri_parameter_supported": True, }, "services": { "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, @@ -217,6 +216,7 @@ def iss_id(iss): class TestRPHandler(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.rph = RPHandler( @@ -279,13 +279,13 @@ def test_init_client(self): _github_id = iss_id("github") _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) # The key jar should only contain a symmetric key that is the clients # secret. 2 because one is marked for encryption and the other signing # usage. - assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} + assert set(_keyjar.owners()) == {"", _context.claims.prefer["client_id"], _github_id, self.rph.entity_id} keys = _keyjar.get_issuer_keys("") assert len(keys) == 3 @@ -329,9 +329,9 @@ def test_do_client_setup(self): assert _context.issuer == _github_id _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} + assert set(_keyjar.owners()) == {"", _context.claims.prefer["client_id"], _github_id, self.rph.entity_id} keys = _keyjar.get_issuer_keys("") assert len(keys) == 3 @@ -347,7 +347,7 @@ def test_create_callbacks(self): cb = _context.get_preference("callback_uris") assert set(cb.keys()) == {"request_uris", "redirect_uris"} - assert set(cb["redirect_uris"].keys()) == {"query", "fragment"} + assert set(cb["redirect_uris"].keys()) == {"query", "fragment", "form_post"} _hash = _context.iss_hash assert cb["redirect_uris"]["query"] == [f"https://example.com/rp/authz_cb/{_hash}"] @@ -449,7 +449,7 @@ def test_get_tokens(self): _github_id = iss_id("github") _context = client.get_context() _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] _iss = _session["iss"] @@ -524,7 +524,7 @@ def test_access_and_id_token(self): _github_id = iss_id("github") _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -571,7 +571,7 @@ def test_access_and_id_token_by_reference(self): _github_id = iss_id("github") _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -618,7 +618,7 @@ def test_get_user_info(self): _github_id = iss_id("github") _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -697,6 +697,7 @@ def test_get_provider_specific_service(): class TestRPHandlerTier2(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) @@ -712,7 +713,7 @@ def rphandler_setup(self): _github_id = iss_id("github") _keyjar = _context.upstream_get("attribute", "keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -818,6 +819,7 @@ def test_get_valid_access_token(self): class MockResponse: + def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -825,6 +827,7 @@ def __init__(self, status_code, text, headers=None): class MockOP(object): + def __init__(self, issuer, keyjar=None): self.keyjar = keyjar self.issuer = issuer @@ -913,6 +916,7 @@ def test_rphandler_request(): class TestRPHandlerWithMockOP(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.issuer = "https://github.com/login/oauth/authorize" @@ -956,7 +960,7 @@ def test_finalize(self): ) _github_id = iss_id("github") _keyjar = client.get_attribute("keyjar") - _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) + _keyjar = import_jwks(_keyjar, GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) with responses.RequestsMock() as rsps: rsps.add( "POST", diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index 177b7721..c4389728 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -45,7 +45,6 @@ def test_init_client(self): 'id_token_encryption_alg_values_supported', 'id_token_encryption_enc_values_supported', 'id_token_signing_alg_values_supported', - 'jwks_uri', 'redirect_uris', 'request_object_encryption_alg_values_supported', 'request_object_encryption_enc_values_supported', @@ -61,7 +60,7 @@ def test_init_client(self): 'userinfo_signing_alg_values_supported'} _keyjar = client.get_attribute("keyjar") - assert list(_keyjar.owners()) == ["", BASE_URL] + assert list(_keyjar.owners()) == ["", "https://example.com"] keys = _keyjar.get_issuer_keys("") assert len(keys) == 2 @@ -116,7 +115,6 @@ def test_begin(self): "encrypt_request_object_supported", "grant_types", "id_token_signed_response_alg", - "jwks_uri", "redirect_uris", "request_object_signing_alg", "response_modes", @@ -180,4 +178,4 @@ def test_begin_2(self): rsps.add("POST", request_uri, body=_jws, status=200) self.rph.do_client_registration(client, ISS_ID) - assert "jwks_uri" in _context.get("registration_response") + assert "jwks_uri" not in _context.get("registration_response") diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index 8edce035..7cfb38e3 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -176,6 +176,7 @@ def get_state_from_url(url): class TestRPHandler(object): + def test_pick_config(self): rph_1 = RPHandler( BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] @@ -251,7 +252,7 @@ def test_do_client_setup(self): _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} + assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id, 'https://example.com/rp'} keys = _keyjar.get_issuer_keys("") assert len(keys) == 3 # one symmetric, one RSA and one EC diff --git a/tests/test_server_01_construct.py b/tests/test_server_01_construct.py index 1dbdc492..014201ea 100644 --- a/tests/test_server_01_construct.py +++ b/tests/test_server_01_construct.py @@ -7,7 +7,6 @@ def test_construct(): default_capabilities = { "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, "response_types_supported": ["code", "token", "code token"], "response_modes_supported": ["query", "fragment", "form_post"], "request_object_signing_alg_values_supported": None, diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index bf4b8286..f96b676c 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -4,6 +4,7 @@ import pytest from cryptojwt.key_jar import build_keyjar +from idpyoidc import alg_info from idpyoidc import metadata from idpyoidc.server import OPConfiguration from idpyoidc.server import Server @@ -28,9 +29,9 @@ class Endpoint_1(Endpoint): name = "userinfo" _supports = { "claim_types_supported": ["normal", "aggregated", "distributed"], - "userinfo_signing_alg_values_supported": metadata.get_signing_algs(), - "userinfo_encryption_alg_values_supported": metadata.get_encryption_algs(), - "userinfo_encryption_enc_values_supported": metadata.get_encryption_encs(), + "userinfo_signing_alg_values_supported": alg_info.get_signing_algs(), + "userinfo_encryption_alg_values_supported": alg_info.get_encryption_algs(), + "userinfo_encryption_enc_values_supported": alg_info.get_encryption_encs(), "client_authn_method": ["bearer_header", "bearer_body"], "encrypt_userinfo_supported": False, } diff --git a/tests/test_server_20a_server.py b/tests/test_server_20a_server.py index 8413b019..fd9fa2a5 100755 --- a/tests/test_server_20a_server.py +++ b/tests/test_server_20a_server.py @@ -122,7 +122,6 @@ def test_capabilities_default(): "id_token", "code id_token", } - assert server.context.provider_info["request_uri_parameter_supported"] is True assert server.context.get_preference("jwks_uri") == "https://127.0.0.1:443/static/jwks.json" diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index 925424c1..fb353fde 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -60,7 +60,10 @@ "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", - ] + ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_parameter_supported": True } AUTH_REQ = AuthorizationRequest( @@ -165,7 +168,7 @@ def create_endpoint(self): "issuer": "https://example.com/", "password": "mycket hemligt zebra", "verify_ssl": False, - "capabilities": CAPABILITIES, + "preference": CAPABILITIES, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, "token_handler_args": { "jwks_def": { @@ -203,13 +206,6 @@ def create_endpoint(self): "authorization": { "path": "{}/authorization", "class": Authorization, - "kwargs": { - "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], - "response_modes_supported": ["query", "fragment", "form_post"], - "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, - }, } }, "authentication": {