diff --git a/patroni/api.py b/patroni/api.py index e65f24a53..21a1514ac 100644 --- a/patroni/api.py +++ b/patroni/api.py @@ -37,13 +37,13 @@ logger = logging.getLogger(__name__) -def check_access(func: Callable[..., None]) -> Callable[..., None]: +def check_access(*args: Any, **kwargs: Any) -> Callable[..., Any]: """Check the source ip, authorization header, or client certificates. .. note:: The actual logic to check access is implemented through :func:`RestApiServer.check_access`. - :param func: function to be decorated. + Optionally it is possible to skip source ip check by specifying ``allowlist_check_members=False``. :returns: a decorator that executes *func* only if :func:`RestApiServer.check_access` returns ``True``. @@ -60,19 +60,31 @@ def check_access(func: Callable[..., None]) -> Callable[..., None]: ... @check_access ... def do_PUT_foo(self): ... print('In do_PUT_foo') + ... @check_access(allowlist_check_members=False) + ... def do_POST_bar(self): + ... print('In do_POST_bar') >>> f = Foo() >>> f.do_PUT_foo() In FooServer: Foo In do_PUT_foo - """ + allowlist_check_members = kwargs.get('allowlist_check_members', True) + + def inner_decorator(func: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: 'RestApiHandler', *args: Any, **kwargs: Any) -> None: + if self.server.check_access(self, allowlist_check_members=allowlist_check_members): + return func(self, *args, **kwargs) - def wrapper(self: 'RestApiHandler', *args: Any, **kwargs: Any) -> None: - if self.server.check_access(self): - return func(self, *args, **kwargs) + return wrapper - return wrapper + # A hacky way to have decorators that work with and without parameters. + if len(args) == 1 and callable(args[0]): + # The first parameter is a function, it means decorator is used as "@check_access" + return inner_decorator(args[0]) + else: + # @check_access(allowlist_check_members=False) case + return inner_decorator class RestApiHandler(BaseHTTPRequestHandler): @@ -747,7 +759,7 @@ def do_GET_failsafe(self) -> None: else: self.send_error(502) - @check_access + @check_access(allowlist_check_members=False) def do_POST_failsafe(self) -> None: """Handle a ``POST`` request to ``/failsafe`` path. @@ -1501,7 +1513,7 @@ def __members_ips(self) -> Iterator[Union[IPv4Network, IPv6Network]]: except Exception as e: logger.debug('Failed to parse url %s: %r', member.api_url, e) - def check_access(self, rh: RestApiHandler) -> Optional[bool]: + def check_access(self, rh: RestApiHandler, allowlist_check_members: bool = True) -> Optional[bool]: """Ensure client has enough privileges to perform a given request. Write a response back to the client if any issue is observed, and the HTTP status may be: @@ -1514,12 +1526,17 @@ def check_access(self, rh: RestApiHandler) -> Optional[bool]: * a client certificate is expected by the server, but is missing in the request. :param rh: the request which access should be checked. + :param allowlist_check_members: whether we should check the source ip against existing cluster members. :returns: ``True`` if client access verification succeeded, otherwise ``None``. """ - if self.__allowlist or self.__allowlist_include_members: + allowlist_check_members = allowlist_check_members and bool(self.__allowlist_include_members) + if self.__allowlist or allowlist_check_members: incoming_ip = ip_address(rh.client_address[0]) - if not any(incoming_ip in net for net in self.__allowlist + tuple(self.__members_ips())): + + members_ips = tuple(self.__members_ips()) if allowlist_check_members else () + + if not any(incoming_ip in net for net in self.__allowlist + members_ips): return rh.write_response(403, 'Access is denied') if not hasattr(rh.request, 'getpeercert') or not rh.request.getpeercert(): # valid client cert isn't present