diff --git a/example/flask_op/static/jwks.json b/example/flask_op/static/jwks.json index e6dc3e44..2ac9940a 100644 --- a/example/flask_op/static/jwks.json +++ b/example/flask_op/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "bXNmZXROQ3N2dDI2SWY5VlNWTG5yOXZqYlpLenVsalhwUWR5RW9BMHNCaw", "n": "uGVI-b6qr-OTc2knp7bpmDtiCQoWFXZ8mUV-SX0rCMtcc_IRmc_J7AfNEfnYk3dv0cKQK_Dgv3vicoeuf4KQ9ZZY-xI3bnRl9_HnhRpz_cJScDirkNKlsv8aQuYBO_gIiHp8B32YC0nx3BUQV5I6QGEiyG-lZT9PmXsUO1uKPPhny_vtQ6cUpvtuLySBu2ZYpaTDQqCv5Y6EKC49NYWhBB4B6f6TNKCoQTaxA8ZoM3lh7kFbu5DPEXKFAtuNiOtUNP7Ei9KfBtyBYSaZQBY8VkwAm1yKCA2sfv1mBwx0dT53MPJlNkoltf89mv1NM2OJPQAgGE6ygwGS2fyBLAn_bQ", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "U0pLNmFBRE4waDYyZG9ZdjNPb2pTZXAwZzdrbmpZdG0ya3lpaFJwZU9ncw", "crv": "P-256", "x": "DYUyBfiD53SEtUuKLjFCFpIkqyhbmBppAMjOat9qiY0", "y": "-SUSvVeOv7EA84qHLLEkDP24iZree-fomICuA4baeeA"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "bXNmZXROQ3N2dDI2SWY5VlNWTG5yOXZqYlpLenVsalhwUWR5RW9BMHNCaw", "e": "AQAB", "n": "uGVI-b6qr-OTc2knp7bpmDtiCQoWFXZ8mUV-SX0rCMtcc_IRmc_J7AfNEfnYk3dv0cKQK_Dgv3vicoeuf4KQ9ZZY-xI3bnRl9_HnhRpz_cJScDirkNKlsv8aQuYBO_gIiHp8B32YC0nx3BUQV5I6QGEiyG-lZT9PmXsUO1uKPPhny_vtQ6cUpvtuLySBu2ZYpaTDQqCv5Y6EKC49NYWhBB4B6f6TNKCoQTaxA8ZoM3lh7kFbu5DPEXKFAtuNiOtUNP7Ei9KfBtyBYSaZQBY8VkwAm1yKCA2sfv1mBwx0dT53MPJlNkoltf89mv1NM2OJPQAgGE6ygwGS2fyBLAn_bQ"}, {"kty": "EC", "use": "sig", "kid": "U0pLNmFBRE4waDYyZG9ZdjNPb2pTZXAwZzdrbmpZdG0ya3lpaFJwZU9ncw", "crv": "P-256", "x": "DYUyBfiD53SEtUuKLjFCFpIkqyhbmBppAMjOat9qiY0", "y": "-SUSvVeOv7EA84qHLLEkDP24iZree-fomICuA4baeeA"}]} \ No newline at end of file diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py index 59ff0687..1427005b 100644 --- a/src/idpyoidc/client/claims/__init__.py +++ b/src/idpyoidc/client/claims/__init__.py @@ -16,10 +16,7 @@ class Claims(claims.Claims): def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): _base = configuration.get("base_url") if not _base: - if entity_id: - _base = entity_id - else: - _base = configuration.get("client_id") + _base = configuration.get("client_id", configuration.get("entity_id")) return _base diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 3624f324..41177d02 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -490,27 +490,29 @@ def _get_signing_key(self, algorithm, keyjar, key_types, kid=None): return signing_key def _get_audience_and_algorithm(self, context, keyjar, **kwargs): - algorithm = None - - # audience for the signed JWT depends on which endpoint - # we're talking to. - if "authn_endpoint" in kwargs and kwargs["authn_endpoint"] in ["token_endpoint"]: - algorithm = context.get_usage("token_endpoint_auth_signing_alg") - if algorithm is None: - _pi = context.provider_info - try: - algs = _pi["token_endpoint_auth_signing_alg_values_supported"] - except KeyError: - algorithm = "RS256" # default - else: - for alg in algs: # pick the first one I support and have keys for - if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar(alg, keyjar): - algorithm = alg - break - - audience = context.provider_info.get("token_endpoint") - else: - audience = context.provider_info["issuer"] + algorithm = kwargs.get("algorithm", None) + audience = kwargs.get("audience", None) + + if not audience: + # audience for the signed JWT depends on which endpoint + # we're talking to. + if "authn_endpoint" in kwargs and kwargs["authn_endpoint"] in ["token_endpoint"]: + algorithm = context.get_usage("token_endpoint_auth_signing_alg") + if algorithm is None: + _pi = context.provider_info + try: + algs = _pi["token_endpoint_auth_signing_alg_values_supported"] + except KeyError: + algorithm = "RS256" # default + else: + for alg in algs: # pick the first one I support and have keys for + if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar(alg, keyjar): + algorithm = alg + break + + audience = context.provider_info.get("token_endpoint") + else: + audience = context.provider_info["issuer"] if not algorithm: algorithm = self.choose_algorithm(**kwargs) @@ -519,6 +521,9 @@ def _get_audience_and_algorithm(self, context, keyjar, **kwargs): def _construct_client_assertion(self, service, **kwargs): _context = service.upstream_get("context") _entity = service.upstream_get("entity") + if _entity is None: + _entity = service.upstream_get("unit") + _keyjar = service.upstream_get("attribute", "keyjar") audience, algorithm = self._get_audience_and_algorithm(_context, _keyjar, **kwargs) @@ -527,7 +532,11 @@ def _construct_client_assertion(self, service, **kwargs): algorithm, _keyjar, _context.kid["sig"], kid=kwargs["kid"] ) else: - signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"]) + _key_type = _context.kid.get("sig", None) + if _key_type: + signing_key = self._get_signing_key(algorithm, _keyjar, _key_type) + else: + signing_key = self.get_signing_key_from_keyjar(algorithm, _keyjar) if not signing_key: raise UnsupportedAlgorithm(algorithm) @@ -570,7 +579,8 @@ def modify_request(self, request, service, **kwargs): pass # If client_id is not required to be present, remove it. - if not request.c_param["client_id"][VREQUIRED]: + _cid_spec = request.c_param.get("client_id", None) + if _cid_spec and not _cid_spec[VREQUIRED]: try: del request["client_id"] except KeyError: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index 231d56b0..d8684e90 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -8,23 +8,18 @@ from typing import Union from urllib.parse import urlparse -from cryptojwt.exception import IssuerNotFound from cryptojwt.jwe.jwe import factory as jwe_factory from cryptojwt.jws.jws import factory as jws_factory from cryptojwt.jwt import JWT -from idpyoidc.exception import MissingSigningKey from idpyoidc.client.exception import Unsupported +from idpyoidc.exception import MissingSigningKey from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.util import importer - -from ..constant import JOSE_ENCODED -from ..constant import JSON_ENCODED -from ..constant import URL_ENCODED from .client_auth import client_auth_setup from .client_auth import method_to_item from .client_auth import single_authn_setup @@ -32,6 +27,9 @@ from .exception import ResponseError from .util import get_http_body from .util import get_http_url +from ..constant import JOSE_ENCODED +from ..constant import JSON_ENCODED +from ..constant import URL_ENCODED __author__ = "Roland Hedberg" @@ -79,7 +77,7 @@ class Service(ImpExp): _callback_path = {} def __init__( - self, upstream_get: Callable, conf: Optional[Union[dict, Configuration]] = None, **kwargs + self, upstream_get: Callable, conf: Optional[Union[dict, Configuration]] = None, **kwargs ): ImpExp.__init__(self) @@ -333,7 +331,7 @@ def get_endpoint(self): return self.upstream_get("context").provider_info[self.endpoint_name] def get_authn_header( - self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs + self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs ) -> dict: """ Construct an authorization specification to be sent in the @@ -364,12 +362,15 @@ def get_authn_method(self) -> str: """ return self.default_authn_method + def get_headers_args(self): + return {} + def get_headers( - self, - request: Union[dict, Message], - http_method: str, - authn_method: Optional[str] = "", - **kwargs, + self, + request: Union[dict, Message], + http_method: str, + authn_method: Optional[str] = "", + **kwargs, ) -> dict: """ @@ -404,7 +405,7 @@ def get_headers( return _headers def get_request_parameters( - self, request_args=None, method="", request_body_type="", authn_method="", **kwargs + self, request_args=None, method="", request_body_type="", authn_method="", **kwargs ) -> dict: """ Builds the request message and constructs the HTTP headers. @@ -445,6 +446,7 @@ def get_request_parameters( # Client authentication by usage of the Authorization HTTP header # or by modifying the request object + _args.update(self.get_headers_args()) _headers = self.get_headers(request, http_method=method, authn_method=authn_method, **_args) # Find out where to send this request @@ -506,7 +508,7 @@ def post_parse_response(self, response, **kwargs): return response def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -542,7 +544,7 @@ def _do_jwt(self, info): def _do_response(self, info, sformat, **kwargs): _context = self.upstream_get("context") - if isinstance(info, list): # Don't have support for sformat=list + if isinstance(info, list): # Don't have support for sformat=list return info try: @@ -566,13 +568,13 @@ def _do_response(self, info, sformat, **kwargs): return resp def parse_response( - self, - info, - sformat: Optional[str] = "", - state: Optional[str] = "", - behaviour_args: Optional[dict] = None, - **kwargs, - ) : + self, + info, + sformat: Optional[str] = "", + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, + **kwargs, + ): """ This the start of a pipeline that will: @@ -707,12 +709,12 @@ def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" def construct_uris( - self, - base_url: str, - hex: bytes, - context: OidcContext, - targets: Optional[List[str]] = None, - response_types: Optional[list] = None, + self, + base_url: str, + hex: bytes, + context: OidcContext, + targets: Optional[List[str]] = None, + response_types: Optional[list] = None, ): if not targets: targets = self._callback_path.keys() diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index f250c31b..b1d4bfda 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -41,11 +41,11 @@ def __init__(self, upstream_get): self.upstream_get = upstream_get def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): """ Verify authentication information in a request @@ -55,12 +55,12 @@ def _verify( raise NotImplementedError() def verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): """ Verify authentication information in a request @@ -78,9 +78,9 @@ def verify( return res def is_usable( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, ): """ Verify that this authentication method is applicable. @@ -117,11 +117,11 @@ def is_usable(self, request=None, authorization_token=None): return request is not None def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request.get("client_id")} @@ -138,11 +138,11 @@ def is_usable(self, request=None, authorization_token=None): return request and "client_id" in request def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request["client_id"]} @@ -162,11 +162,11 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): client_info = basic_authn(authorization_token) _context = self.upstream_get("context") @@ -194,11 +194,11 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): _context = self.upstream_get("context") if _context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: @@ -218,12 +218,12 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): logger.debug(f"Client Auth method: {self.tag}") token = authorization_token.split(" ", 1)[1] @@ -255,12 +255,12 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): _token = request.get("access_token") if _token is None: @@ -275,6 +275,7 @@ def _verify( class JWSAuthnMethod(ClientAuthnMethod): + def is_usable(self, request=None, authorization_token=None): if request is None: return False @@ -283,12 +284,12 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - key_type: Optional[str] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + key_type: Optional[str] = None, + **kwargs, ): _context = self.upstream_get("context") _keyjar = self.upstream_get("attribute", "keyjar") @@ -351,11 +352,11 @@ class ClientSecretJWT(JWSAuthnMethod): tag = "client_secret_jwt" def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( request=request, key_type="client_secret", endpoint=endpoint, **kwargs @@ -372,11 +373,11 @@ class PrivateKeyJWT(JWSAuthnMethod): tag = "private_key_jwt" def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( request=request, @@ -397,11 +398,11 @@ def is_usable(self, request=None, authorization_token=None): return True def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): _context = self.upstream_get("context") _jwt = JWT(self.upstream_get("attribute", "keyjar"), msg_cls=JsonWebToken) @@ -450,12 +451,12 @@ def valid_client_secret(cinfo): def verify_client( - request: Union[dict, Message], - http_info: Optional[dict] = None, - get_client_id_from_token: Optional[Callable] = None, - endpoint=None, # Optional[Endpoint] - also_known_as: Optional[Dict[str, str]] = None, - **kwargs, + request: Union[dict, Message], + http_info: Optional[dict] = None, + get_client_id_from_token: Optional[Callable] = None, + endpoint=None, # Optional[Endpoint] + also_known_as: Optional[Dict[str, str]] = None, + **kwargs, ) -> dict: """ Initiated Guessing ! @@ -480,7 +481,7 @@ def verify_client( auth_info = {} _context = endpoint.upstream_get("context") - methods = _context.client_authn_methods + methods = getattr(_context, "client_authn_methods", None) client_id = None allowed_methods = getattr(endpoint, "client_authn_method") @@ -488,6 +489,7 @@ def verify_client( allowed_methods = list(methods.keys()) # If not specific for this endpoint then all _method = None + _cdb = _cinfo = None for _method in (methods[meth] for meth in allowed_methods): if not _method.is_usable(request=request, authorization_token=authorization_token): continue @@ -519,12 +521,13 @@ def verify_client( client_id = also_known_as[client_id] auth_info["client_id"] = client_id - _get_client_info = kwargs.get("get_client_info") + _get_client_info = kwargs.get("get_client_info", None) if _get_client_info: - _cinfo = _get_client_info(client_id, _context) + _cinfo = _get_client_info(client_id, endpoint) else: + _cdb = getattr(_context, "cdb", None) try: - _cinfo = _context.cdb[client_id] + _cinfo = _cdb[client_id] except KeyError: raise UnknownClient("Unknown Client ID") @@ -537,7 +540,7 @@ def verify_client( # Validate that the used method is allowed for this client/endpoint client_allowed_methods = _cinfo.get( - f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method") + f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method", None) ) if client_allowed_methods is not None and auth_info["method"] not in client_allowed_methods: logger.info( @@ -551,13 +554,13 @@ def verify_client( logger.debug(f"Authn methods applied") # store what authn method was used - if "method" in auth_info and client_id: + if "method" in auth_info and client_id and _cdb: _request_type = request.__class__.__name__ _used_authn_method = _cinfo.get("auth_method") if _used_authn_method: - _context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] + _cdb[client_id]["auth_method"][_request_type] = auth_info["method"] else: - _context.cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} + _cdb[client_id]["auth_method"] = {_request_type: auth_info["method"]} return auth_info diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 256f0029..07c0080b 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -238,7 +238,7 @@ def parse_request( if _auth_method and _auth_method not in ["public", "none"]: req["authenticated"] = True else: - _client_id = req.get("client_id") + _client_id = req.get("client_id", None) LOGGER.debug(f"parse_request:auth_info:{auth_info}") @@ -275,10 +275,11 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No authn_info = verify_client(request=request, http_info=http_info, **kwargs) LOGGER.debug("authn_info: %s", authn_info) - if authn_info == {} and self.client_authn_method and len(self.client_authn_method): - LOGGER.debug("client_authn_method: %s", self.client_authn_method) - raise UnAuthorizedClient("Authorization failed") - if "client_id" not in authn_info and authn_info.get("method") != "none": + if authn_info == {}: + if self.client_authn_method and len(self.client_authn_method): + LOGGER.debug("client_authn_method: %s", self.client_authn_method) + raise UnAuthorizedClient("Authorization failed") + elif "client_id" not in authn_info and authn_info.get("method") != "none": raise UnAuthorizedClient("Authorization failed") return authn_info diff --git a/src/idpyoidc/storage/listfile.py b/src/idpyoidc/storage/listfile.py new file mode 100644 index 00000000..1fdfd009 --- /dev/null +++ b/src/idpyoidc/storage/listfile.py @@ -0,0 +1,136 @@ +import logging +import os +import time + +from filelock import FileLock + +logger = logging.getLogger(__name__) + + +class ReadOnlyListFileMtime(object): + + def __init__(self, file_name): + self.file_name = file_name + self.fmtime = 0 + + if not os.path.exists(file_name): + fp = open(file_name, "x") + fp.close() + _lst = [] + else: + _lst = self._read_info(self.file_name) + + def __getitem__(self, item): + if self.is_changed(self.file_name): + _lst = self._read_info(self.file_name) + if _lst: + return _lst[item] + else: + return None + + def __len__(self): + if self.is_changed(self.file_name): + _lst = self._read_info(self.file_name) + if _lst is None or _lst == []: + return 0 + + return len(_lst) + + @staticmethod + def get_mtime(fname): + """ + Find the time this file was last modified. + + :param fname: File name + :return: The last time the file was modified. + """ + try: + mtime = os.path.getmtime(fname) + except OSError: + # The file might be right in the middle of being created + # so sleep + time.sleep(1) + mtime = os.path.getmtime(fname) + + return mtime + + def is_changed(self, fname): + """ + Find out if this file has been modified since last + + :param fname: A file name + :return: True/False + """ + if os.path.isfile(fname): + mtime = self.get_mtime(fname) + + if self.fmtime == 0: + self.fmtime = mtime + return True + + if mtime != self.fmtime: # has changed + self.fmtime = mtime + return True + else: + return False + else: + logger.error("Could not access {}".format(fname)) + raise FileNotFoundError() + + def _read_info(self, fname): + if os.path.isfile(fname): + try: + lock = FileLock(f"{fname}.lock") + with lock: + fp = open(fname, "r") + info = [x.strip() for x in fp.readlines()] + lock.release() + return info or None + except Exception as err: + logger.error(err) + raise + else: + _msg = f"No such file: '{fname}'" + logger.error(_msg) + return None + + +class ReadOnlyListFile(object): + + def __init__(self, file_name): + self.file_name = file_name + + if not os.path.exists(file_name): + fp = open(file_name, "x") + fp.close() + + def __getitem__(self, item): + _lst = self._read_info(self.file_name) + if _lst: + return _lst[item] + else: + return None + + def __len__(self): + _lst = self._read_info(self.file_name) + if _lst is None or _lst == []: + return 0 + + return len(_lst) + + def _read_info(self, fname): + if os.path.isfile(fname): + try: + lock = FileLock(f"{fname}.lock") + with lock: + fp = open(fname, "r") + info = [x.strip() for x in fp.readlines()] + lock.release() + return info or None + except Exception as err: + logger.error(err) + raise + else: + _msg = f"No such file: '{fname}'" + logger.error(_msg) + return None diff --git a/tests/pub_iss.jwks b/tests/pub_iss.jwks index 9b062907..77081f40 100644 --- a/tests/pub_iss.jwks +++ b/tests/pub_iss.jwks @@ -1 +1 @@ -{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "e": "AQAB", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw"}]} \ No newline at end of file +{"keys": [{"kty": "EC", "use": "sig", "kid": "SmdKMlVGcG1zMnprdDdXZGpGWEczdHhlZVpGbkx1THpPdUY4d0w4bnZkSQ", "crv": "P-256", "x": "tRHJYm0fsOi0icpGEb33qiDVgt68ltMoYSWdLGhDGz4", "y": "fRpX0i6p5Jigf5I0qwW34PyStosMShwWAWS8x_w5o7E"}, {"kty": "RSA", "use": "sig", "kid": "R0FsaFdqREFaUFp1c0MwbUpsbHVSZ200blBJZWJVMTUtNGsyVlBmdHk5UQ", "n": "2ilgsKVqF92KfhwmosSVeZOaDgb3RF1mbg-pqkmLO6YpOO06LF4V4angF-GhP-ysAm2E75aSIU4tnHVThFlcxTgKFqjYKJQXyVzTVK2r-L2IbvFPaDtvoU6WteybpMlIUVk2po3cFDGObCWYKCm7CUOLlwH0uOpui66P9VSCqdKVKbJRAQBvTSbP10KWPxulfqjWGJtHO5fY7-JVWwOBkG-eHSJIT_uaoPjyvKCZjknq04bLUV9qP78KRQpRyYijBN60w2v8F79baN9CN10TIEjjWKGz0uX0M_YYQzTUoSY5l5ka9RkL3wT4o2iQ1t5nHphX6aA-gqwgCQmi-nvjaw", "e": "AQAB"}]} \ No newline at end of file diff --git a/tests/test_14_read_only_list_file.py b/tests/test_14_read_only_list_file.py new file mode 100644 index 00000000..2abdf9e9 --- /dev/null +++ b/tests/test_14_read_only_list_file.py @@ -0,0 +1,29 @@ +import os +from time import sleep + +from idpyoidc.storage.listfile import ReadOnlyListFile + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + +FILE_NAME = full_path("read_only") +def test_read_only_list_file(): + if os.path.exists(FILE_NAME): + os.unlink(FILE_NAME) + if os.path.exists(f"{FILE_NAME}.lock"): + os.unlink(f"{FILE_NAME}.lock") + + _read_only = ReadOnlyListFile(FILE_NAME) + assert len(_read_only) == 0 + + with open(FILE_NAME, "w") as fp: + for line in ["one", "two", "three"]: + fp.write(line + '\n') + + # sleep(2) + # assert _read_only.is_changed(FILE_NAME) is True + assert set(_read_only) == {"one", "two", "three"} + assert _read_only[-1] == "three" \ No newline at end of file