diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index aea92626b..f24fa76fe 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -157,3 +157,18 @@ jobs: steps: - run: bash <(curl -Ls https://coverage.codacy.com/get.sh) final if: ${{ env.SECRETS_AVAILABLE == 'true' }} + + pyright: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + + - name: Install dependencies + run: python -m pip install -r requirements.txt psycopg2-binary psycopg + + - uses: jakebailey/pyright-action@v1 diff --git a/patroni/api.py b/patroni/api.py index 0fb39157f..f461811b0 100644 --- a/patroni/api.py +++ b/patroni/api.py @@ -24,7 +24,7 @@ from threading import Thread from urllib.parse import urlparse, parse_qs -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union from . import psycopg from .__main__ import Patroni @@ -86,7 +86,8 @@ def __init__(self, request: Any, :param client_address: address of the client connection. :param server: HTTP server that received the request. 
""" - assert isinstance(server, RestApiServer) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(server, RestApiServer) super(RestApiHandler, self).__init__(request, client_address, server) self.server: 'RestApiServer' = server self.__start_time: float = 0.0 @@ -113,8 +114,8 @@ def _write_status_code_only(self, status_code: int) -> None: self.wfile.write('{0} {1} {2}\r\n\r\n'.format(self.protocol_version, status_code, message).encode('utf-8')) self.log_request(status_code) - def _write_response(self, status_code: int, body: str, content_type: str = 'text/html', - headers: Optional[Dict[str, str]] = None) -> None: + def write_response(self, status_code: int, body: str, content_type: str = 'text/html', + headers: Optional[Dict[str, str]] = None) -> None: """Write an HTTP response. .. note:: @@ -143,12 +144,12 @@ def _write_response(self, status_code: int, body: str, content_type: str = 'text def _write_json_response(self, status_code: int, response: Any) -> None: """Write an HTTP response with a JSON content type. - Call :func:`_write_response` with ``content_type`` as ``application/json``. + Call :func:`write_response` with ``content_type`` as ``application/json``. :param status_code: response HTTP status code. :param response: value to be dumped as a JSON string and to be used as the response body. """ - self._write_response(status_code, json.dumps(response, default=str), content_type='application/json') + self.write_response(status_code, json.dumps(response, default=str), content_type='application/json') def _write_status_response(self, status_code: int, response: Dict[str, Any]) -> None: """Write an HTTP response with Patroni/Postgres status in JSON format. @@ -457,7 +458,7 @@ def do_GET_metrics(self) -> None: * ``patroni_xlog_paused``: ``pg_is_wal_replay_paused()``; * ``patroni_postgres_server_version``: Postgres version without periods, e.g. 
``150002`` for Postgres ``15.2``; * ``patroni_cluster_unlocked``: ``1`` if no one holds the leader lock, else ``0``; - * ``patroni_failsafe_mode_is_active``: ``1`` if ``failmode`` is currently active, else ``0``; + * ``patroni_failsafe_mode_is_active``: ``1`` if ``failsafe_mode`` is currently active, else ``0``; * ``patroni_postgres_timeline``: PostgreSQL timeline based on current WAL file name; * ``patroni_dcs_last_seen``: epoch timestamp when DCS was last contacted successfully; * ``patroni_pending_restart``: ``1`` if this PostgreSQL node is pending a restart, else ``0``; @@ -565,7 +566,7 @@ def do_GET_metrics(self) -> None: metrics.append("# TYPE patroni_is_paused gauge") metrics.append("patroni_is_paused{0} {1}".format(scope_label, int(postgres.get('pause', 0)))) - self._write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain') + self.write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain') def _read_json_content(self, body_is_optional: bool = False) -> Optional[Dict[Any, Any]]: """Read JSON from HTTP request body. @@ -615,12 +616,12 @@ def do_PATCH_config(self) -> None: request = self._read_json_content() if request: cluster = self.server.patroni.dcs.get_cluster(True) - if not (cluster.config and cluster.config.modify_index): + if not (cluster.config and cluster.config.modify_version): return self.send_error(503) data = cluster.config.data.copy() if patch_config(data, request): value = json.dumps(data, separators=(',', ':')) - if not self.server.patroni.dcs.set_config_value(value, cluster.config.index): + if not self.server.patroni.dcs.set_config_value(value, cluster.config.version): return self.send_error(409) self.server.patroni.ha.wakeup() self._write_json_response(200, data) @@ -651,7 +652,7 @@ def do_POST_reload(self) -> None: Schedules a reload to Patroni and writes a response with HTTP status `202`. 
""" self.server.patroni.sighup_handler() - self._write_response(202, 'reload scheduled') + self.write_response(202, 'reload scheduled') def do_GET_failsafe(self) -> None: """Handle a ``GET`` request to ``/failsafe`` path. @@ -684,7 +685,7 @@ def do_POST_failsafe(self) -> None: if request: message = self.server.patroni.ha.update_failsafe(request) or 'Accepted' code = 200 if message == 'Accepted' else 500 - self._write_response(code, message) + self.write_response(code, message) else: self.send_error(502) @@ -699,7 +700,7 @@ def do_POST_sigterm(self) -> None: """ if os.name == 'nt' and os.getenv('BEHAVE_DEBUG'): self.server.patroni.api_sigterm() - self._write_response(202, 'shutdown scheduled') + self.write_response(202, 'shutdown scheduled') @staticmethod def parse_schedule(schedule: str, @@ -779,7 +780,7 @@ def do_POST_restart(self) -> None: logger.debug("received restart request: {0}".format(request)) if self.server.patroni.config.get_global_config(cluster).is_paused and 'schedule' in request: - self._write_response(status_code, "Can't schedule restart in the paused state") + self.write_response(status_code, "Can't schedule restart in the paused state") return for k in request: @@ -827,8 +828,9 @@ def do_POST_restart(self) -> None: status_code = 409 # pyright thinks ``data`` can be ``None`` because ``parse_schedule`` call may return ``None``. 
However, if # that's the case, ``data`` will be overwritten when the ``for`` loop ends - assert isinstance(data, str) - self._write_response(status_code, data) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(data, str) + self.write_response(status_code, data) @check_access def do_DELETE_restart(self) -> None: @@ -846,7 +848,7 @@ def do_DELETE_restart(self) -> None: else: data = "no restarts are scheduled" code = 404 - self._write_response(code, data) + self.write_response(code, data) @check_access def do_DELETE_switchover(self) -> None: @@ -861,7 +863,7 @@ def do_DELETE_switchover(self) -> None: """ failover = self.server.patroni.dcs.get_cluster().failover if failover and failover.scheduled_at: - if not self.server.patroni.dcs.manual_failover('', '', index=failover.index): + if not self.server.patroni.dcs.manual_failover('', '', version=failover.version): return self.send_error(409) else: data = "scheduled switchover deleted" @@ -869,7 +871,7 @@ def do_DELETE_switchover(self) -> None: else: data = "no switchover is scheduled" code = 404 - self._write_response(code, data) + self.write_response(code, data) @check_access def do_POST_reinitialize(self) -> None: @@ -895,7 +897,7 @@ def do_POST_reinitialize(self) -> None: data = 'reinitialize started' else: status_code = 503 - self._write_response(status_code, data) + self.write_response(status_code, data) def poll_failover_result(self, leader: Optional[str], candidate: Optional[str], action: str) -> Tuple[int, str]: """Poll failover/switchover operation until it finishes or times out. @@ -1033,9 +1035,10 @@ def do_POST_failover(self, action: str = 'failover') -> None: status_code = 503 # pyright thinks ``status_code`` can be ``None`` because ``parse_schedule`` call may return ``None``. However, # if that's the case, ``status_code`` will be overwritten somewhere between ``parse_schedule`` and - # ``_write_response`` calls. 
- assert isinstance(status_code, int) - self._write_response(status_code, data) + # ``write_response`` calls. + if TYPE_CHECKING: # pragma: no cover + assert isinstance(status_code, int) + self.write_response(status_code, data) def do_POST_switchover(self) -> None: """Handle a ``POST`` request to ``/switchover`` path. @@ -1062,7 +1065,7 @@ def do_POST_citus(self) -> None: if patroni.postgresql.citus_handler.is_coordinator() and patroni.ha.is_leader(): cluster = patroni.dcs.get_cluster(True) patroni.postgresql.citus_handler.handle_event(cluster, request) - self._write_response(200, 'OK') + self.write_response(200, 'OK') def parse_request(self) -> bool: """Override :func:`parse_request` method to enrich basic functionality of :class:`BaseHTTPRequestHandler`. @@ -1242,7 +1245,7 @@ def __init__(self, patroni: Patroni, config: Dict[str, Any]) -> None: :param patroni: Patroni daemon process. :param config: ``restapi`` section of Patroni configuration. """ - self.connection_string = None + self.connection_string: str self.__auth_key = None self.__allowlist_include_members: Optional[bool] = None self.__allowlist: Tuple[Union[IPv4Network, IPv6Network], ...] = () @@ -1300,7 +1303,8 @@ def check_basic_auth_key(self, key: str) -> bool: :returns: ``True`` if *key* matches the password configured for the REST API. """ # pyright -- ``__auth_key`` was already checked through the caller method (:func:`check_auth_header`). 
- assert self.__auth_key is not None + if TYPE_CHECKING: # pragma: no cover + assert self.__auth_key is not None return hmac.compare_digest(self.__auth_key, key.encode('utf-8')) def check_auth_header(self, auth_header: Optional[str]) -> Optional[str]: @@ -1371,16 +1375,16 @@ def check_access(self, rh: RestApiHandler) -> Optional[bool]: if self.__allowlist or self.__allowlist_include_members: incoming_ip = ip_address(rh.client_address[0]) if not any(incoming_ip in net for net in self.__allowlist + tuple(self.__members_ips())): - return rh._write_response(403, 'Access is denied') + return rh.write_response(403, 'Access is denied') if not hasattr(rh.request, 'getpeercert') or not rh.request.getpeercert(): # valid client cert isn't present if self.__protocol == 'https' and self.__ssl_options.get('verify_client') in ('required', 'optional'): - return rh._write_response(403, 'client certificate required') + return rh.write_response(403, 'client certificate required') reason = self.check_auth_header(rh.headers.get('Authorization')) if reason: headers = {'WWW-Authenticate': 'Basic realm="' + self.patroni.__class__.__name__ + '"'} - return rh._write_response(401, reason, headers=headers) + return rh.write_response(401, reason, headers=headers) return True @staticmethod @@ -1603,7 +1607,8 @@ def reload_config(self, config: Dict[str, Any]) -> None: self.__auth_key = base64.b64encode(config['auth'].encode('utf-8')) if 'auth' in config else None # pyright -- ``__listen`` is initially created as ``None``, but right after that it is replaced with a string # through :func:`__initialize`. 
- assert isinstance(self.__listen, str) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(self.__listen, str) self.connection_string = uri(self.__protocol, config.get('connect_address') or self.__listen, 'patroni') def handle_error(self, request: Union[socket.socket, Tuple[bytes, socket.socket]], diff --git a/patroni/config.py b/patroni/config.py index 4d3eea8ed..abd2371d3 100644 --- a/patroni/config.py +++ b/patroni/config.py @@ -7,7 +7,7 @@ from collections import defaultdict from copy import deepcopy -from typing import Any, Callable, Collection, Dict, List, Optional, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Union, TYPE_CHECKING from . import PATRONI_ENV_PREFIX from .collections import CaseInsensitiveDict @@ -151,7 +151,7 @@ def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict[str, :returns: :class:`GlobalConfig` object """ # Try to protect from the case when DCS was wiped out - if cluster and cluster.config and cluster.config.modify_index: + if cluster and cluster.config and cluster.config.modify_version: config = cluster.config.data else: config = default or {} @@ -202,7 +202,7 @@ class Config(object): def __init__(self, configfile: str, validator: Optional[Callable[[Dict[str, Any]], List[str]]] = default_validator) -> None: - self._modify_index = -1 + self._modify_version = -1 self._dynamic_configuration = {} self.__environment_configuration = self._build_environment_configuration() @@ -257,7 +257,8 @@ def _load_config_path(self, path: str) -> Dict[str, Any]: def _load_config_file(self) -> Dict[str, Any]: """Loads config.yaml from filesystem and applies some values which were set via ENV""" - assert self._config_file is not None + if TYPE_CHECKING: # pragma: no cover + assert self._config_file is not None config = self._load_config_path(self._config_file) patch_config(config, self.__environment_configuration) return config @@ -296,9 +297,9 @@ def save_cache(self) -> None: # configuration could 
be either ClusterConfig or dict def set_dynamic_configuration(self, configuration: Union[ClusterConfig, Dict[str, Any]]) -> bool: if isinstance(configuration, ClusterConfig): - if self._modify_index == configuration.modify_index: - return False # If the index didn't changed there is nothing to do - self._modify_index = configuration.modify_index + if self._modify_version == configuration.modify_version: + return False # If the version didn't change there is nothing to do + self._modify_version = configuration.modify_version configuration = configuration.data if not deep_compare(self._dynamic_configuration, configuration): diff --git a/patroni/ctl.py b/patroni/ctl.py index 786ea4ecc..612a97db3 100644 --- a/patroni/ctl.py +++ b/patroni/ctl.py @@ -32,9 +32,9 @@ from psycopg2 import cursor try: - from ydiff import markup_to_pager, PatchStream + from ydiff import markup_to_pager, PatchStream # pyright: ignore [reportMissingModuleSource] except ImportError: # pragma: no cover - from cdiff import markup_to_pager, PatchStream + from cdiff import markup_to_pager, PatchStream # pyright: ignore [reportMissingModuleSource] from .dcs import get_dcs as _get_dcs, AbstractDCS, Cluster, Member from .exceptions import PatroniException @@ -812,7 +812,8 @@ def _do_failover_or_switchover(obj: Dict[str, Any], action: str, cluster_name: s r = None try: member = cluster.leader.member if cluster.leader else candidate and cluster.get_member(candidate, False) - assert isinstance(member, Member) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(member, Member) r = request_patroni(member, 'post', action, failover_value) # probably old patroni, which doesn't support switchover yet @@ -1052,7 +1053,7 @@ def flush(obj: Dict[str, Any], cluster_name: str, group: Optional[int], logging.warning('Failing over to DCS') click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name)) - dcs.manual_failover('', '', index=failover.index) + 
dcs.manual_failover('', '', version=failover.version) def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Cluster) -> None: @@ -1060,7 +1061,7 @@ def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Clu config = get_global_config(old_cluster) click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume')) - old = {m.name: m.index for m in old_cluster.members if m.api_url} + old = {m.name: m.version for m in old_cluster.members if m.api_url} loop_wait = config.get('loop_wait') or dcs.loop_wait cluster = None @@ -1072,7 +1073,7 @@ def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Clu if TYPE_CHECKING: # pragma: no cover assert cluster is not None remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused - and m.name in old and old[m.name] != m.index] + and m.name in old and old[m.name] != m.version] if remaining: return click.echo("{0} members didn't recognized pause state after {1} seconds" .format(', '.join(remaining), loop_wait)) @@ -1169,7 +1170,7 @@ class opts: ( os.path.basename(p) for p in (os.environ.get('PAGER'), "less", "more") - if p is not None and shutil.which(p) + if p is not None and bool(shutil.which(p)) ), None, ) @@ -1347,7 +1348,7 @@ def edit_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int], return if force or click.confirm('Apply these changes?'): - if not dcs.set_config_value(json.dumps(changed_data), cluster.config.index): + if not dcs.set_config_value(json.dumps(changed_data), cluster.config.version): raise PatroniCtlException("Config modification aborted due to concurrent changes") click.echo("Configuration changed") @@ -1396,7 +1397,8 @@ def version(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member @click.pass_obj def history(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None: cluster = get_dcs(obj, cluster_name, 
group).get_cluster() - history: List[List[Any]] = list(map(list, cluster.history and cluster.history.lines or [])) + cluster_history = cluster.history.lines if cluster.history else [] + history: List[List[Any]] = list(map(list, cluster_history)) table_header_row = ['TL', 'LSN', 'Reason', 'Timestamp', 'New Leader'] for line in history: if len(line) < len(table_header_row): diff --git a/patroni/dcs/__init__.py b/patroni/dcs/__init__.py index f1429c6f2..1ff4a60a1 100644 --- a/patroni/dcs/__init__.py +++ b/patroni/dcs/__init__.py @@ -126,7 +126,7 @@ def get_dcs(config: Union['Config', Dict[str, Any]]) -> 'AbstractDCS': class Member(NamedTuple): """Immutable object (namedtuple) which represents single member of PostgreSQL cluster. Consists of the following fields: - :param index: modification index of a given member key in a Configuration Store + :param version: modification version of a given member key in a Configuration Store :param name: name of PostgreSQL cluster member :param session: either session id or just ttl in seconds :param data: arbitrary data i.e. conn_url, api_url, xlog location, state, role, tags, etc... @@ -135,18 +135,18 @@ class Member(NamedTuple): conn_url: connection string containing host, user and password which could be used to access this member. 
api_url: REST API url of patroni instance """ - index: _Version + version: _Version name: str session: _Session data: Dict[str, Any] @staticmethod - def from_node(index: _Version, name: str, session: _Session, value: str) -> 'Member': + def from_node(version: _Version, name: str, session: _Session, value: str) -> 'Member': """ >>> Member.from_node(-1, '', '', '{"conn_url": "postgres://foo@bar/postgres"}') is not None True >>> Member.from_node(-1, '', '', '{') - Member(index=-1, name='', session='', data={}) + Member(version=-1, name='', session='', data={}) """ if value.startswith('postgres'): conn_url, api_url = parse_connection_string(value) @@ -157,7 +157,7 @@ def from_node(index: _Version, name: str, session: _Session, value: str) -> 'Mem assert isinstance(data, dict) except (AssertionError, TypeError, ValueError): data: Dict[str, Any] = {} - return Member(index, name, session, data) + return Member(version, name, session, data) @property def conn_url(self) -> Optional[str]: @@ -229,7 +229,7 @@ def is_running(self) -> bool: return self.state == 'running' @property - def version(self) -> Optional[Tuple[int, ...]]: + def patroni_version(self) -> Optional[Tuple[int, ...]]: version = self.data.get('version') if version: try: @@ -240,7 +240,9 @@ def version(self) -> Optional[Tuple[int, ...]]: class RemoteMember(Member): """Represents a remote member (typically a primary) for a standby cluster""" - def __new__(cls, name: str, data: Dict[str, Any]) -> 'RemoteMember': + + @classmethod + def from_name_and_data(cls, name: str, data: Dict[str, Any]) -> 'RemoteMember': return super(RemoteMember, cls).__new__(cls, -1, name, None, data) @staticmethod @@ -261,11 +263,11 @@ class Leader(NamedTuple): """Immutable object (namedtuple) which represents leader key. 
Consists of the following fields: - :param index: modification index of a leader key in a Configuration Store + :param version: modification version of a leader key in a Configuration Store :param session: either session id or just ttl in seconds :param member: reference to a `Member` object which represents current leader (see `Cluster.members`) """ - index: _Version + version: _Version session: _Session member: Member @@ -294,7 +296,7 @@ def checkpoint_after_promote(self) -> Optional[bool]: >>> Leader(1, '', Member.from_node(1, '', '', '{"version":"z"}')).checkpoint_after_promote """ - version = self.member.version + version = self.member.patroni_version # 1.5.6 is the last version which doesn't expose checkpoint_after_promote: false if version and version > (1, 5, 6): return self.data.get('role') in ('master', 'primary') and 'checkpoint_after_promote' not in self.data @@ -321,13 +323,13 @@ class Failover(NamedTuple): >>> 'abc' in Failover.from_node(1, 'abc:def') True """ - index: _Version + version: _Version leader: Optional[str] candidate: Optional[str] scheduled_at: Optional[datetime.datetime] @staticmethod - def from_node(index: _Version, value: Union[str, Dict[str, str]]) -> 'Failover': + def from_node(version: _Version, value: Union[str, Dict[str, str]]) -> 'Failover': if isinstance(value, dict): data: Dict[str, Any] = value elif value: @@ -340,26 +342,26 @@ def from_node(index: _Version, value: Union[str, Dict[str, str]]) -> 'Failover': t = [a.strip() for a in value.split(':')] leader = t[0] candidate = t[1] if len(t) > 1 else None - return Failover(index, leader, candidate, None) + return Failover(version, leader, candidate, None) else: data = {} if data.get('scheduled_at'): data['scheduled_at'] = dateutil.parser.parse(data['scheduled_at']) - return Failover(index, data.get('leader'), data.get('member'), data.get('scheduled_at')) + return Failover(version, data.get('leader'), data.get('member'), data.get('scheduled_at')) def __len__(self) -> int: return 
int(bool(self.leader)) + int(bool(self.candidate)) class ClusterConfig(NamedTuple): - index: _Version + version: _Version data: Dict[str, Any] - modify_index: _Version + modify_version: _Version @staticmethod - def from_node(index: _Version, value: str, modify_index: Optional[_Version] = None) -> 'ClusterConfig': + def from_node(version: _Version, value: str, modify_version: Optional[_Version] = None) -> 'ClusterConfig': """ >>> ClusterConfig.from_node(1, '{') is None False @@ -370,8 +372,8 @@ def from_node(index: _Version, value: str, modify_index: Optional[_Version] = No assert isinstance(data, dict) except (AssertionError, TypeError, ValueError): data: Dict[str, Any] = {} - modify_index = 0 - return ClusterConfig(index, data, index if modify_index is None else modify_index) + modify_version = 0 + return ClusterConfig(version, data, version if modify_version is None else modify_version) @property def permanent_slots(self) -> Dict[str, Any]: @@ -390,16 +392,16 @@ def max_timelines_history(self) -> int: class SyncState(NamedTuple): """Immutable object (namedtuple) which represents last observed synhcronous replication state - :param index: modification index of a synchronization key in a Configuration Store + :param version: modification version of a synchronization key in a Configuration Store :param leader: reference to member that was leader :param sync_standby: synchronous standby list (comma delimited) which are last synchronized to leader """ - index: Optional[_Version] + version: Optional[_Version] leader: Optional[str] sync_standby: Optional[str] @staticmethod - def from_node(index: Optional[_Version], value: Union[str, Dict[str, Any], None]) -> 'SyncState': + def from_node(version: Optional[_Version], value: Union[str, Dict[str, Any], None]) -> 'SyncState': """ >>> SyncState.from_node(1, None).leader is None True @@ -418,13 +420,13 @@ def from_node(index: Optional[_Version], value: Union[str, Dict[str, Any], None] if value and isinstance(value, str): value 
= json.loads(value) assert isinstance(value, dict) - return SyncState(index, value.get('leader'), value.get('sync_standby')) + return SyncState(version, value.get('leader'), value.get('sync_standby')) except (AssertionError, TypeError, ValueError): - return SyncState.empty(index) + return SyncState.empty(version) @staticmethod - def empty(index: Optional[_Version] = None) -> 'SyncState': - return SyncState(index, None, None) + def empty(version: Optional[_Version] = None) -> 'SyncState': + return SyncState(version, None, None) @property def is_empty(self) -> bool: @@ -484,12 +486,12 @@ def leader_matches(self, name: Optional[str]) -> bool: class TimelineHistory(NamedTuple): """Object representing timeline history file""" - index: _Version + version: _Version value: Any lines: List[_HistoryTuple] @staticmethod - def from_node(index: _Version, value: str) -> 'TimelineHistory': + def from_node(version: _Version, value: str) -> 'TimelineHistory': """ >>> h = TimelineHistory.from_node(1, 2) >>> h.lines @@ -500,7 +502,7 @@ def from_node(index: _Version, value: str) -> 'TimelineHistory': assert isinstance(lines, list) except (AssertionError, TypeError, ValueError): lines: List[_HistoryTuple] = [] - return TimelineHistory(index, value, lines) + return TimelineHistory(version, value, lines) class Cluster(NamedTuple): @@ -537,7 +539,7 @@ def empty() -> 'Cluster': def is_empty(self): return self.initialize is None and self.config is None and self.leader is None and self.last_lsn == 0\ - and self.members == [] and self.failover is None and self.sync.index is None\ + and self.members == [] and self.failover is None and self.sync.version is None\ and self.history is None and self.slots is None and self.failsafe is None and self.workers == {} def __len__(self) -> int: @@ -701,7 +703,7 @@ def timeline(self) -> int: @property def min_version(self) -> Optional[Tuple[int, ...]]: - return next(iter(sorted(m.version for m in self.members if m.version)), None) + return 
next(iter(sorted(m.patroni_version for m in self.members if m.patroni_version)), None) class ReturnFalseException(Exception): @@ -870,7 +872,8 @@ def __get_patroni_cluster(self, path: Optional[str] = None) -> Cluster: if path is None: path = self.client_path('') cluster = self._load_cluster(path, self._cluster_loader) - assert isinstance(cluster, Cluster) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(cluster, Cluster) return cluster def is_citus_coordinator(self) -> bool: @@ -887,7 +890,6 @@ def _get_citus_cluster(self) -> Cluster: if isinstance(groups, Cluster): # Zookeeper could return a cached version cluster = groups else: - assert isinstance(groups, dict) cluster = groups.pop(CITUS_COORDINATOR_GROUP_ID, Cluster.empty()) cluster.workers.update(groups) return cluster @@ -1006,11 +1008,11 @@ def attempt_to_acquire_leader(self) -> bool: process requests (hopefuly temporary), the ~DCSError exception should be raised""" @abc.abstractmethod - def set_failover_value(self, value: str, index: Optional[Any] = None) -> bool: + def set_failover_value(self, value: str, version: Optional[Any] = None) -> bool: """Create or update `/failover` key""" def manual_failover(self, leader: Optional[str], candidate: Optional[str], - scheduled_at: Optional[datetime.datetime] = None, index: Optional[Any] = None) -> bool: + scheduled_at: Optional[datetime.datetime] = None, version: Optional[Any] = None) -> bool: failover_value = {} if leader: failover_value['leader'] = leader @@ -1020,10 +1022,10 @@ def manual_failover(self, leader: Optional[str], candidate: Optional[str], if scheduled_at: failover_value['scheduled_at'] = scheduled_at.isoformat() - return self.set_failover_value(json.dumps(failover_value, separators=(',', ':')), index) + return self.set_failover_value(json.dumps(failover_value, separators=(',', ':')), version) @abc.abstractmethod - def set_config_value(self, value: str, index: Optional[Any] = None) -> bool: + def set_config_value(self, value: str, version: 
Optional[Any] = None) -> bool: """Create or update `/config` key""" @abc.abstractmethod @@ -1087,16 +1089,16 @@ def sync_state(leader: Optional[str], sync_standby: Optional[Collection[str]]) - return {'leader': leader, 'sync_standby': ','.join(sorted(sync_standby)) if sync_standby else None} def write_sync_state(self, leader: Optional[str], sync_standby: Optional[Collection[str]], - index: Optional[Any] = None) -> Optional[SyncState]: + version: Optional[Any] = None) -> Optional[SyncState]: """Write the new synchronous state to DCS. Calls :func:`sync_state` method to build a dict and than calls DCS specific :func:`set_sync_state_value` method. :param leader: name of the leader node that manages /sync key :param sync_standby: collection of currently known synchronous standby node names - :param index: for conditional update of the key/object + :param version: for conditional update of the key/object :returns: the new :class:`SyncState` object or None """ sync_value = self.sync_state(leader, sync_standby) - ret = self.set_sync_state_value(json.dumps(sync_value, separators=(',', ':')), index) + ret = self.set_sync_state_value(json.dumps(sync_value, separators=(',', ':')), version) if not isinstance(ret, bool): return SyncState.from_node(ret, sync_value) @@ -1105,23 +1107,23 @@ def set_history_value(self, value: str) -> bool: """""" @abc.abstractmethod - def set_sync_state_value(self, value: str, index: Optional[Any] = None) -> Union[Any, bool]: + def set_sync_state_value(self, value: str, version: Optional[Any] = None) -> Union[Any, bool]: """Set synchronous state in DCS, should be implemented in the child class. 
:param value: the new value of /sync key - :param index: for conditional update of the key/object + :param version: for conditional update of the key/object :returns: version of the new object or `False` in case of error """ @abc.abstractmethod - def delete_sync_state(self, index: Optional[Any] = None) -> bool: + def delete_sync_state(self, version: Optional[Any] = None) -> bool: """""" - def watch(self, leader_index: Optional[Any], timeout: float) -> bool: + def watch(self, leader_version: Optional[Any], timeout: float) -> bool: """If the current node is a leader it should just sleep. Any other node should watch for changes of leader key with a given timeout - :param leader_index: index of a leader key + :param leader_version: version of a leader key :param timeout: timeout in seconds :returns: `!True` if you would like to reschedule the next run of ha cycle""" diff --git a/patroni/dcs/consul.py b/patroni/dcs/consul.py index 2d2d6708d..2383a9643 100644 --- a/patroni/dcs/consul.py +++ b/patroni/dcs/consul.py @@ -578,12 +578,12 @@ def take_leader(self) -> bool: return self.attempt_to_acquire_leader() @catch_consul_errors - def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: - return self._client.kv.put(self.failover_path, value, cas=index) + def set_failover_value(self, value: str, version: Optional[int] = None) -> bool: + return self._client.kv.put(self.failover_path, value, cas=version) @catch_consul_errors - def set_config_value(self, value: str, index: Optional[int] = None) -> bool: - return self._client.kv.put(self.config_path, value, cas=index) + def set_config_value(self, value: str, version: Optional[int] = None) -> bool: + return self._client.kv.put(self.config_path, value, cas=version) @catch_consul_errors def _write_leader_optime(self, last_lsn: str) -> bool: @@ -622,7 +622,8 @@ def _update_leader(self) -> bool: raise ConsulError('update_leader timeout') logger.warning('Recreating the leader key due to session mismatch') if 
cluster and cluster.leader: - self._run_and_handle_exceptions(self._client.kv.delete, self.leader_path, cas=cluster.leader.index) + self._run_and_handle_exceptions(self._client.kv.delete, self.leader_path, + cas=cluster.leader.version) retry.deadline = retry.stoptime - time.time() if retry.deadline < 0.5: @@ -653,14 +654,14 @@ def set_history_value(self, value: str) -> bool: def _delete_leader(self) -> bool: cluster = self.cluster if cluster and isinstance(cluster.leader, Leader) and\ - cluster.leader.name == self._name and isinstance(cluster.leader.index, int): - return self._client.kv.delete(self.leader_path, cas=cluster.leader.index) + cluster.leader.name == self._name and isinstance(cluster.leader.version, int): + return self._client.kv.delete(self.leader_path, cas=cluster.leader.version) return True @catch_consul_errors - def set_sync_state_value(self, value: str, index: Optional[int] = None) -> Union[int, bool]: + def set_sync_state_value(self, value: str, version: Optional[int] = None) -> Union[int, bool]: retry = self._retry.copy() - ret = retry(self._client.kv.put, self.sync_path, value, cas=index) + ret = retry(self._client.kv.put, self.sync_path, value, cas=version) if ret: # We have no other choise, only read after write :( retry.deadline = retry.stoptime - time.time() if retry.deadline < 0.5: @@ -671,21 +672,21 @@ def set_sync_state_value(self, value: str, index: Optional[int] = None) -> Union return False @catch_consul_errors - def delete_sync_state(self, index: Optional[int] = None) -> bool: - return self.retry(self._client.kv.delete, self.sync_path, cas=index) + def delete_sync_state(self, version: Optional[int] = None) -> bool: + return self.retry(self._client.kv.delete, self.sync_path, cas=version) - def watch(self, leader_index: Optional[int], timeout: float) -> bool: + def watch(self, leader_version: Optional[int], timeout: float) -> bool: self._last_session_refresh = 0 if self.__do_not_watch: self.__do_not_watch = False return True - if 
leader_index: + if leader_version: end_time = time.time() + timeout while timeout >= 1: try: - idx, _ = self._client.kv.get(self.leader_path, index=leader_index, wait=str(timeout) + 's') - return str(idx) != str(leader_index) + idx, _ = self._client.kv.get(self.leader_path, index=leader_version, wait=str(timeout) + 's') + return str(idx) != str(leader_version) except (ConsulException, HTTPException, HTTPError, socket.error, socket.timeout): logger.exception('watch') diff --git a/patroni/dcs/etcd.py b/patroni/dcs/etcd.py index e53d8a671..190b79fce 100644 --- a/patroni/dcs/etcd.py +++ b/patroni/dcs/etcd.py @@ -290,7 +290,8 @@ def api_execute(self, path: str, method: str, params: Optional[Dict[str, Any]] = etcd_nodes = len(machines_cache) except Exception as e: logger.debug('Failed to update list of etcd nodes: %r', e) - assert isinstance(retry, Retry) # etcd.EtcdConnectionFailed is raised only if retry is not None! + if TYPE_CHECKING: # pragma: no cover + assert isinstance(retry, Retry) # etcd.EtcdConnectionFailed is raised only if retry is not None! 
sleeptime = retry.sleeptime remaining_time = retry.stoptime - sleeptime - time.time() nodes, timeout, retries = self._calculate_timeouts(etcd_nodes, remaining_time) @@ -502,6 +503,17 @@ def _handle_exception(self, e: Exception, name: str = '', do_sleep: bool = False if isinstance(raise_ex, Exception): raise raise_ex + def handle_etcd_exceptions(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + try: + retval = func(self, *args, **kwargs) + self._has_failed = False + return retval + except (RetryFailedError, etcd.EtcdException) as e: + self._handle_exception(e) + return False + except Exception as e: + self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error')) + def _run_and_handle_exceptions(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: retry = kwargs.pop('retry', self.retry) try: @@ -624,16 +636,7 @@ def set_retry_timeout(self, retry_timeout: int) -> None: def catch_etcd_errors(func: Callable[..., Any]) -> Any: def wrapper(self: AbstractEtcd, *args: Any, **kwargs: Any) -> Any: - try: - retval = func(self, *args, **kwargs) - self._has_failed = False - return retval - except (RetryFailedError, etcd.EtcdException) as e: - self._handle_exception(e) - return False - except Exception as e: - self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error')) - + return self.handle_etcd_exceptions(func, *args, **kwargs) return wrapper @@ -645,7 +648,8 @@ def __init__(self, config: Dict[str, Any]) -> None: @property def _client(self) -> EtcdClient: - assert isinstance(self._abstract_client, EtcdClient) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(self._abstract_client, EtcdClient) return self._abstract_client def set_ttl(self, ttl: int) -> Optional[bool]: @@ -696,8 +700,8 @@ def _cluster_from_nodes(self, etcd_index: int, nodes: Dict[str, etcd.EtcdResult] if leader: member = Member(-1, leader.value, None, {}) member = ([m for m in members if m.name == leader.value] or [member])[0] - index 
= etcd_index if etcd_index > leader.modifiedIndex else leader.modifiedIndex + 1 - leader = Leader(index, leader.ttl, member) + version = etcd_index if etcd_index > leader.modifiedIndex else leader.modifiedIndex + 1 + leader = Leader(version, leader.ttl, member) # failover key failover = nodes.get(self._FAILOVER) @@ -742,7 +746,8 @@ def _load_cluster( except Exception as e: self._handle_exception(e, 'get_cluster', raise_ex=EtcdError('Etcd is not responding properly')) self._has_failed = False - assert cluster is not None + if TYPE_CHECKING: # pragma: no cover + assert cluster is not None return cluster @catch_etcd_errors @@ -766,12 +771,12 @@ def attempt_to_acquire_leader(self) -> bool: return self._run_and_handle_exceptions(self._do_attempt_to_acquire_leader, retry=None) @catch_etcd_errors - def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: - return bool(self._client.write(self.failover_path, value, prevIndex=index or 0)) + def set_failover_value(self, value: str, version: Optional[int] = None) -> bool: + return bool(self._client.write(self.failover_path, value, prevIndex=version or 0)) @catch_etcd_errors - def set_config_value(self, value: str, index: Optional[int] = None) -> bool: - return bool(self._client.write(self.config_path, value, prevIndex=index or 0)) + def set_config_value(self, value: str, version: Optional[int] = None) -> bool: + return bool(self._client.write(self.config_path, value, prevIndex=version or 0)) @catch_etcd_errors def _write_leader_optime(self, last_lsn: str) -> bool: @@ -817,24 +822,24 @@ def set_history_value(self, value: str) -> bool: return bool(self._client.write(self.history_path, value)) @catch_etcd_errors - def set_sync_state_value(self, value: str, index: Optional[int] = None) -> Union[int, bool]: - return self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0).modifiedIndex + def set_sync_state_value(self, value: str, version: Optional[int] = None) -> Union[int, bool]: + return 
self.retry(self._client.write, self.sync_path, value, prevIndex=version or 0).modifiedIndex @catch_etcd_errors - def delete_sync_state(self, index: Optional[int] = None) -> bool: - return bool(self.retry(self._client.delete, self.sync_path, prevIndex=index or 0)) + def delete_sync_state(self, version: Optional[int] = None) -> bool: + return bool(self.retry(self._client.delete, self.sync_path, prevIndex=version or 0)) - def watch(self, leader_index: Optional[int], timeout: float) -> bool: + def watch(self, leader_version: Optional[int], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True - if leader_index: + if leader_version: end_time = time.time() + timeout while timeout >= 1: # when timeout is too small urllib3 doesn't have enough time to connect try: - result = self._client.watch(self.leader_path, index=leader_index, timeout=timeout + 0.5) + result = self._client.watch(self.leader_path, index=leader_version, timeout=timeout + 0.5) self._has_failed = False if result.action == 'compareAndSwap': time.sleep(0.01) diff --git a/patroni/dcs/etcd3.py b/patroni/dcs/etcd3.py index bdf5cc336..308bb2c3e 100644 --- a/patroni/dcs/etcd3.py +++ b/patroni/dcs/etcd3.py @@ -13,7 +13,7 @@ from enum import IntEnum from urllib3.exceptions import ReadTimeoutError, ProtocolError from threading import Condition, Lock, Thread -from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, TYPE_CHECKING, Union from . 
import ClusterConfig, Cluster, Failover, Leader, Member, SyncState,\ TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re @@ -145,7 +145,8 @@ class InvalidAuthToken(Etcd3ClientError): def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]]]], status_code: Optional[int] = None) -> Etcd3ClientError: try: - assert isinstance(data, dict) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(data, dict) data_error: Optional[Dict[str, Any]] = data.get('error') or data.get('Error') if isinstance(data_error, dict): # streaming response status_code = data_error.get('http_code') @@ -153,7 +154,8 @@ def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]] error: str = data_error['message'] else: data_code = data.get('code') or data.get('Code') - assert not isinstance(data_code, dict) + if TYPE_CHECKING: # pragma: no cover + assert not isinstance(data_code, dict) code = data_code error = str(data_error) except Exception: @@ -193,28 +195,7 @@ def build_range_request(key: str, range_end: Union[bytes, str, None] = None) -> def _handle_auth_errors(func: Callable[..., Any]) -> Any: def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any: - def retry(ex: Exception) -> Any: - if self.username and self.password: - self.authenticate() - return func(self, *args, **kwargs) - else: - logger.fatal('Username or password not set, authentication is not possible') - raise ex - - try: - return func(self, *args, **kwargs) - except (UserEmpty, PermissionDenied) as e: # no token provided - # PermissionDenied is raised on 3.0 and 3.1 - if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) - or self._cluster_version < (3, 2)): - raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' - 'supported on version lower than 3.3.0. 
Cluster version: ' - '{0}'.format('.'.join(map(str, self._cluster_version)))) - return retry(e) - except InvalidAuthToken as e: - logger.error('Invalid auth token: %s', self._token) - return retry(e) - + return self.handle_auth_errors(func, *args, **kwargs) return wrapper @@ -322,6 +303,29 @@ def authenticate(self) -> bool: self._token = response.get('token') return old_token != self._token + def handle_auth_errors(self: 'Etcd3Client', func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + def retry(ex: Exception) -> Any: + if self.username and self.password: + self.authenticate() + return func(self, *args, **kwargs) + else: + logger.fatal('Username or password not set, authentication is not possible') + raise ex + + try: + return func(self, *args, **kwargs) + except (UserEmpty, PermissionDenied) as e: # no token provided + # PermissionDenied is raised on 3.0 and 3.1 + if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) + or self._cluster_version < (3, 2)): + raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' + 'supported on version lower than 3.3.0. 
Cluster version: ' + '{0}'.format('.'.join(map(str, self._cluster_version)))) + return retry(e) + except InvalidAuthToken as e: + logger.error('Invalid auth token: %s', self._token) + return retry(e) + @_handle_auth_errors def range(self, key: str, range_end: Union[bytes, str, None] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: @@ -401,7 +405,7 @@ def __init__(self, dcs: 'Etcd3', client: 'PatroniEtcd3Client') -> None: self._leader_key = base64_encode(dcs.leader_path) self._optime_key = base64_encode(dcs.leader_optime_path) self._status_key = base64_encode(dcs.status_path) - self._name = base64_encode(dcs._name) + self._name = base64_encode(getattr(dcs, '_name')) # pyright self._is_ready = False self._response = None self._response_lock = Lock() @@ -582,9 +586,9 @@ def _wait_cache(self, timeout: float) -> None: self._kv_cache.condition.wait(timeout) def get_cluster(self, path: str) -> List[Dict[str, Any]]: - if self._kv_cache and self._etcd3._retry.deadline is not None and path.startswith(self._etcd3.cluster_prefix): + if self._kv_cache and path.startswith(self._etcd3.cluster_prefix): with self._kv_cache.condition: - self._wait_cache(self._etcd3._retry.deadline) + self._wait_cache(self.read_timeout) ret = self._kv_cache.copy() else: ret = self._etcd3.retry(self.prefix, path).get('kvs', []) @@ -621,7 +625,6 @@ class Etcd3(AbstractEtcd): def __init__(self, config: Dict[str, Any]) -> None: super(Etcd3, self).__init__(config, PatroniEtcd3Client, (DeadlineExceeded, Unavailable, FailedPrecondition)) - assert isinstance(self._client, PatroniEtcd3Client) self.__do_not_watch = False self._lease = None self._last_lease_refresh = 0 @@ -633,12 +636,14 @@ def __init__(self, config: Dict[str, Any]) -> None: @property def _client(self) -> PatroniEtcd3Client: - assert isinstance(self._abstract_client, PatroniEtcd3Client) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(self._abstract_client, PatroniEtcd3Client) return self._abstract_client def 
set_socket_options(self, sock: socket.socket, socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None: - assert self._retry.deadline is not None + if TYPE_CHECKING: # pragma: no cover + assert self._retry.deadline is not None enable_keepalive(sock, self.ttl, int(self.loop_wait + self._retry.deadline)) def set_ttl(self, ttl: int) -> Optional[bool]: @@ -773,7 +778,8 @@ def _load_cluster( except Exception as e: self._handle_exception(e, 'get_cluster', raise_ex=Etcd3Error('Etcd is not responding properly')) self._has_failed = False - assert cluster is not None + if TYPE_CHECKING: # pragma: no cover + assert cluster is not None return cluster @catch_etcd_errors @@ -841,12 +847,12 @@ def _retry(*args: Any, **kwargs: Any) -> Any: return ret @catch_etcd_errors - def set_failover_value(self, value: str, index: Optional[str] = None) -> bool: - return bool(self._client.put(self.failover_path, value, mod_revision=index)) + def set_failover_value(self, value: str, version: Optional[str] = None) -> bool: + return bool(self._client.put(self.failover_path, value, mod_revision=version)) @catch_etcd_errors - def set_config_value(self, value: str, index: Optional[str] = None) -> bool: - return bool(self._client.put(self.config_path, value, mod_revision=index)) + def set_config_value(self, value: str, version: Optional[str] = None) -> bool: + return bool(self._client.put(self.config_path, value, mod_revision=version)) @catch_etcd_errors def _write_leader_optime(self, last_lsn: str) -> bool: @@ -893,7 +899,7 @@ def initialize(self, create_new: bool = True, sysid: str = ""): def _delete_leader(self) -> bool: cluster = self.cluster if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name: - return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.index) + return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.version) return True @catch_etcd_errors @@ -909,15 +915,15 @@ def set_history_value(self, 
value: str) -> bool: return bool(self._client.put(self.history_path, value)) @catch_etcd_errors - def set_sync_state_value(self, value: str, index: Optional[str] = None) -> Union[str, bool]: - return self.retry(self._client.put, self.sync_path, value, mod_revision=index)\ + def set_sync_state_value(self, value: str, version: Optional[str] = None) -> Union[str, bool]: + return self.retry(self._client.put, self.sync_path, value, mod_revision=version)\ .get('header', {}).get('revision', False) @catch_etcd_errors - def delete_sync_state(self, index: Optional[str] = None) -> bool: - return self.retry(self._client.deleterange, self.sync_path, mod_revision=index) + def delete_sync_state(self, version: Optional[str] = None) -> bool: + return self.retry(self._client.deleterange, self.sync_path, mod_revision=version) - def watch(self, leader_index: Optional[str], timeout: float) -> bool: + def watch(self, leader_version: Optional[str], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True diff --git a/patroni/dcs/kubernetes.py b/patroni/dcs/kubernetes.py index e33a8509f..3e334e9db 100644 --- a/patroni/dcs/kubernetes.py +++ b/patroni/dcs/kubernetes.py @@ -135,11 +135,14 @@ def load_kube_config(self, context: Optional[str] = None) -> None: context = context or config['current-context'] context_value = self._get_by_name(config, 'context', context) - assert isinstance(context_value, dict) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(context_value, dict) cluster = self._get_by_name(config, 'cluster', context_value['cluster']) - assert isinstance(cluster, dict) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(cluster, dict) user = self._get_by_name(config, 'user', context_value['user']) - assert isinstance(user, dict) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(user, dict) self._server = cluster['server'].rstrip('/') if self._server.startswith('https'): @@ -281,7 +284,8 @@ def _get_api_servers(self, 
api_servers_cache: List[str]) -> List[str]: try: response = self.pool_manager.request('GET', base_uri + path, **kwargs) endpoint = self._handle_server_response(response, True) - assert isinstance(endpoint, K8sObject) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(endpoint, K8sObject) for subset in endpoint.subsets: for port in subset.ports: if port.name == 'https' and port.protocol == 'TCP': @@ -412,7 +416,8 @@ def request( except Exception as e: logger.debug('Failed to update list of K8s master nodes: %r', e) - assert isinstance(retry, Retry) # K8sConnectionFailed is raised only if retry is not None! + if TYPE_CHECKING: # pragma: no cover + assert isinstance(retry, Retry) # K8sConnectionFailed is raised only if retry is not None! sleeptime = retry.sleeptime remaining_time = (retry.stoptime or time.time()) - sleeptime - time.time() nodes, timeout, retries = self._calculate_timeouts(api_servers, remaining_time) @@ -559,10 +564,23 @@ def use_endpoints(self) -> bool: return self._use_endpoints +def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + try: + return method(*args, **kwargs) + except k8s_client.rest.ApiException as e: + if e.status == 403: + logger.exception('Permission denied') + elif e.status != 409: # Object exists or conflict in resource_version + logger.exception('Unexpected error from Kubernetes API') + return False + except (RetryFailedError, K8sException) as e: + raise KubernetesError(e) + + def catch_kubernetes_errors(func: Callable[..., Any]) -> Callable[..., Any]: def wrapper(self: 'Kubernetes', *args: Any, **kwargs: Any) -> Any: try: - return self._run_and_handle_exceptions(func, self, *args, **kwargs) + return _run_and_handle_exceptions(func, self, *args, **kwargs) except KubernetesError: return False return wrapper @@ -584,7 +602,8 @@ def __init__(self, dcs: 'Kubernetes', func: Callable[..., Any], retry: Retry, self._response_lock = Lock() # protect the `self._response` from concurrent 
access self._object_cache: Dict[str, K8sObject] = {} self._object_cache_lock = Lock() - self._annotations_map = {self._dcs.leader_path: self._dcs._LEADER, self._dcs.config_path: self._dcs._CONFIG} + self._annotations_map = {self._dcs.leader_path: getattr(self._dcs, '_LEADER'), + self._dcs.config_path: getattr(self._dcs, '_CONFIG')} # pyright self.start() def _list(self) -> K8sObject: @@ -779,19 +798,6 @@ def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: kwargs['_retry'] = retry return retry(method, *args, **kwargs) - @staticmethod - def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - try: - return method(*args, **kwargs) - except k8s_client.rest.ApiException as e: - if e.status == 403: - logger.exception('Permission denied') - elif e.status != 409: # Object exists or conflict in resource_version - logger.exception('Unexpected error from Kubernetes API') - return False - except (RetryFailedError, K8sException) as e: - raise KubernetesError(e) - def client_path(self, path: str) -> str: return super(Kubernetes, self).client_path(path)[1:].replace('/', '-') @@ -818,7 +824,8 @@ def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: Either cause by changes in the local configuration file + SIGHUP or by changes of dynamic configuration""" super(Kubernetes, self).reload_config(config) - assert self._retry.deadline is not None + if TYPE_CHECKING: # pragma: no cover + assert self._retry.deadline is not None self._api.configure_timeouts(self.loop_wait, self._retry.deadline, self.ttl) # retriable_http_codes supposed to be either int, list of integers or comma-separated string with integers. 
@@ -954,7 +961,8 @@ def _citus_cluster_loader(self, path: Dict[str, Any]) -> Dict[int, Cluster]: def __load_cluster( self, group: Optional[str], loader: Callable[[Dict[str, Any]], Union[Cluster, Dict[int, Cluster]]] ) -> Union[Cluster, Dict[int, Cluster]]: - assert self._retry.deadline is not None + if TYPE_CHECKING: # pragma: no cover + assert self._retry.deadline is not None stop_time = time.time() + self._retry.deadline self._api.refresh_api_servers_cache() try: @@ -978,7 +986,8 @@ def _load_cluster( def get_citus_coordinator(self) -> Optional[Cluster]: try: ret = self.__load_cluster(str(CITUS_COORDINATOR_GROUP_ID), self._cluster_loader) - assert isinstance(ret, Cluster) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(ret, Cluster) return ret except Exception as e: logger.error('Failed to load Citus coordinator cluster from Kubernetes: %r', e) @@ -1170,8 +1179,8 @@ def _retry(*args: Any, **kwargs: Any) -> Any: if kind and (kind_annotations.get(self._LEADER) != self._name or kind_resource_version == resource_version): return False - return bool(self._run_and_handle_exceptions(self._patch_or_create, self.leader_path, annotations, - kind_resource_version, ips=ips, retry=_retry)) + return bool(_run_and_handle_exceptions(self._patch_or_create, self.leader_path, annotations, + kind_resource_version, ips=ips, retry=_retry)) def update_leader(self, last_lsn: Optional[int], slots: Optional[Dict[str, int]] = None, failsafe: Optional[Dict[str, str]] = None) -> bool: @@ -1232,24 +1241,24 @@ def attempt_to_acquire_leader(self) -> bool: def take_leader(self) -> bool: return self.attempt_to_acquire_leader() - def set_failover_value(self, value: str, index: Optional[str] = None) -> bool: + def set_failover_value(self, value: str, version: Optional[str] = None) -> bool: """Unused""" raise NotImplementedError # pragma: no cover def manual_failover(self, leader: Optional[str], candidate: Optional[str], - scheduled_at: Optional[datetime.datetime] = None, index: 
Optional[str] = None) -> bool: + scheduled_at: Optional[datetime.datetime] = None, version: Optional[str] = None) -> bool: annotations = {'leader': leader or None, 'member': candidate or None, 'scheduled_at': scheduled_at and scheduled_at.isoformat()} - patch = bool(self.cluster and isinstance(self.cluster.failover, Failover) and self.cluster.failover.index) - return bool(self.patch_or_create(self.failover_path, annotations, index, bool(index or patch), False)) + patch = bool(self.cluster and isinstance(self.cluster.failover, Failover) and self.cluster.failover.version) + return bool(self.patch_or_create(self.failover_path, annotations, version, bool(version or patch), False)) @property def _config_resource_version(self) -> Optional[str]: config = self._kinds.get(self.config_path) return config and config.metadata.resource_version - def set_config_value(self, value: str, index: Optional[str] = None) -> bool: - return self.patch_or_create_config({self._CONFIG: value}, index, bool(self._config_resource_version), False) + def set_config_value(self, value: str, version: Optional[str] = None) -> bool: + return self.patch_or_create_config({self._CONFIG: value}, version, bool(self._config_resource_version), False) @catch_kubernetes_errors def touch_member(self, data: Dict[str, Any]) -> bool: @@ -1279,7 +1288,8 @@ def touch_member(self, data: Dict[str, Any]) -> bool: def initialize(self, create_new: bool = True, sysid: str = "") -> bool: cluster = self.cluster - resource_version = str(cluster.config.index) if cluster and cluster.config and cluster.config.index else None + resource_version = str(cluster.config.version)\ + if cluster and cluster.config and cluster.config.version else None return self.patch_or_create_config({self._INITIALIZE: sysid}, resource_version) def _delete_leader(self) -> bool: @@ -1308,34 +1318,34 @@ def delete_cluster(self) -> bool: def set_history_value(self, value: str) -> bool: return self.patch_or_create_config({self._HISTORY: value}, None, 
bool(self._config_resource_version), False) - def set_sync_state_value(self, value: str, index: Optional[str] = None) -> bool: + def set_sync_state_value(self, value: str, version: Optional[str] = None) -> bool: """Unused""" raise NotImplementedError # pragma: no cover def write_sync_state(self, leader: Optional[str], sync_standby: Optional[Collection[str]], - index: Optional[str] = None) -> Optional[SyncState]: + version: Optional[str] = None) -> Optional[SyncState]: """Prepare and write annotations to $SCOPE-sync Endpoint or ConfigMap. :param leader: name of the leader node that manages /sync key :param sync_standby: collection of currently known synchronous standby node names - :param index: last known `resource_version` for conditional update of the object + :param version: last known `resource_version` for conditional update of the object :returns: the new :class:`SyncState` object or None """ sync_state = self.sync_state(leader, sync_standby) - ret = self.patch_or_create(self.sync_path, sync_state, index, False) + ret = self.patch_or_create(self.sync_path, sync_state, version, False) if not isinstance(ret, bool): return SyncState.from_node(ret.metadata.resource_version, sync_state) - def delete_sync_state(self, index: Optional[str] = None) -> bool: + def delete_sync_state(self, version: Optional[str] = None) -> bool: """Patch annotations of $SCOPE-sync Endpoint or ConfigMap with empty values. Effectively it removes "leader" and "sync_standby" annotations from the object. 
- :param index: last known `resource_version` for conditional update of the object + :param version: last known `resource_version` for conditional update of the object :returns: `True` if "delete" was successful """ - return self.write_sync_state(None, None, index=index) is not None + return self.write_sync_state(None, None, version=version) is not None - def watch(self, leader_index: Optional[str], timeout: float) -> bool: + def watch(self, leader_version: Optional[str], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True diff --git a/patroni/dcs/raft.py b/patroni/dcs/raft.py index 101139311..9d9b5f6d1 100644 --- a/patroni/dcs/raft.py +++ b/patroni/dcs/raft.py @@ -430,11 +430,11 @@ def attempt_to_acquire_leader(self) -> bool: return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, handle_raft_error=False, prevExist=False) is not False - def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: - return self._sync_obj.set(self.failover_path, value, prevIndex=index) is not False + def set_failover_value(self, value: str, version: Optional[int] = None) -> bool: + return self._sync_obj.set(self.failover_path, value, prevIndex=version) is not False - def set_config_value(self, value: str, index: Optional[int] = None) -> bool: - return self._sync_obj.set(self.config_path, value, prevIndex=index) is not False + def set_config_value(self, value: str, version: Optional[int] = None) -> bool: + return self._sync_obj.set(self.config_path, value, prevIndex=version) is not False def touch_member(self, data: Dict[str, Any]) -> bool: value = json.dumps(data, separators=(',', ':')) @@ -458,17 +458,17 @@ def delete_cluster(self) -> bool: def set_history_value(self, value: str) -> bool: return self._sync_obj.set(self.history_path, value) is not False - def set_sync_state_value(self, value: str, index: Optional[int] = None) -> Union[int, bool]: - ret = self._sync_obj.set(self.sync_path, value, prevIndex=index) + 
def set_sync_state_value(self, value: str, version: Optional[int] = None) -> Union[int, bool]: + ret = self._sync_obj.set(self.sync_path, value, prevIndex=version) if isinstance(ret, dict): return ret['index'] return ret - def delete_sync_state(self, index: Optional[int] = None) -> bool: - return self._sync_obj.delete(self.sync_path, prevIndex=index) + def delete_sync_state(self, version: Optional[int] = None) -> bool: + return self._sync_obj.delete(self.sync_path, prevIndex=version) - def watch(self, leader_index: Optional[int], timeout: float) -> bool: + def watch(self, leader_version: Optional[int], timeout: float) -> bool: try: - return super(Raft, self).watch(leader_index, timeout) + return super(Raft, self).watch(leader_version, timeout) finally: self.event.clear() diff --git a/patroni/dcs/zookeeper.py b/patroni/dcs/zookeeper.py index 3d5f7b4b8..5ec9cd04d 100644 --- a/patroni/dcs/zookeeper.py +++ b/patroni/dcs/zookeeper.py @@ -273,7 +273,7 @@ def _cluster_loader(self, path: str) -> Cluster: member = Member(-1, leader[0], None, {}) member = ([m for m in members if m.name == leader[0]] or [member])[0] leader = Leader(leader[1].version, leader[1].ephemeralOwner, member) - self._fetch_cluster = member.index == -1 + self._fetch_cluster = member.version == -1 # get last known leader lsn and slots last_lsn, slots = self.get_status(path, leader) @@ -357,19 +357,19 @@ def attempt_to_acquire_leader(self) -> bool: logger.info('Could not take out TTL lock') return False - def _set_or_create(self, key: str, value: str, index: Optional[int] = None, + def _set_or_create(self, key: str, value: str, version: Optional[int] = None, retry: bool = False, do_not_create_empty: bool = False) -> Union[int, bool]: value_bytes = value.encode('utf-8') try: if retry: - ret = self._client.retry(self._client.set, key, value_bytes, version=index or -1) + ret = self._client.retry(self._client.set, key, value_bytes, version=version or -1) else: - ret = self._client.set_async(key, value_bytes, 
version=index or -1).get(timeout=1) + ret = self._client.set_async(key, value_bytes, version=version or -1).get(timeout=1) return ret.version except NoNodeError: if do_not_create_empty and not value_bytes: return True - elif index is None: + elif version is None: if self._create(key, value_bytes, retry): return 0 else: @@ -378,11 +378,11 @@ def _set_or_create(self, key: str, value: str, index: Optional[int] = None, logger.exception('Failed to update %s', key) return False - def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: - return self._set_or_create(self.failover_path, value, index) is not False + def set_failover_value(self, value: str, version: Optional[int] = None) -> bool: + return self._set_or_create(self.failover_path, value, version) is not False - def set_config_value(self, value: str, index: Optional[int] = None) -> bool: - return self._set_or_create(self.config_path, value, index, retry=True) is not False + def set_config_value(self, value: str, version: Optional[int] = None) -> bool: + return self._set_or_create(self.config_path, value, version, retry=True) is not False def initialize(self, create_new: bool = True, sysid: str = "") -> bool: sysid_bytes = sysid.encode('utf-8') @@ -494,14 +494,14 @@ def delete_cluster(self) -> bool: def set_history_value(self, value: str) -> bool: return self._set_or_create(self.history_path, value) is not False - def set_sync_state_value(self, value: str, index: Optional[int] = None) -> Union[int, bool]: - return self._set_or_create(self.sync_path, value, index, retry=True, do_not_create_empty=True) + def set_sync_state_value(self, value: str, version: Optional[int] = None) -> Union[int, bool]: + return self._set_or_create(self.sync_path, value, version, retry=True, do_not_create_empty=True) - def delete_sync_state(self, index: Optional[int] = None) -> bool: - return self.set_sync_state_value("{}", index) is not False + def delete_sync_state(self, version: Optional[int] = None) -> bool: + 
return self.set_sync_state_value("{}", version) is not False - def watch(self, leader_index: Optional[int], timeout: float) -> bool: - ret = super(ZooKeeper, self).watch(leader_index, timeout + 0.5) + def watch(self, leader_version: Optional[int], timeout: float) -> bool: + ret = super(ZooKeeper, self).watch(leader_version, timeout + 0.5) if ret and not self._fetch_status: self._fetch_cluster = True return ret or self._fetch_cluster diff --git a/patroni/ha.py b/patroni/ha.py index e1559a279..1190ad1c2 100644 --- a/patroni/ha.py +++ b/patroni/ha.py @@ -100,10 +100,10 @@ def update(self, data: Dict[str, Any]) -> None: @property def leader(self) -> Optional[Leader]: with self._lock: - if self._last_update + self._dcs.ttl > time.time(): - return Leader('', '', RemoteMember(self._name, {'api_url': self._api_url, - 'conn_url': self._conn_url, - 'slots': self._slots})) + if self._last_update + self._dcs.ttl > time.time() and self._name: + return Leader('', '', RemoteMember.from_name_and_data(self._name, {'api_url': self._api_url, + 'conn_url': self._conn_url, + 'slots': self._slots})) def update_cluster(self, cluster: Cluster) -> Cluster: # Enreach cluster with the real leader if there was a ping from it @@ -596,7 +596,7 @@ def process_sync_replication(self) -> None: if sync_common != current: logger.info("Updating synchronous privilege temporarily from %s to %s", list(current), list(sync_common)) - sync = self.dcs.write_sync_state(self.state_handler.name, sync_common, index=sync.index) + sync = self.dcs.write_sync_state(self.state_handler.name, sync_common, version=sync.version) if not sync: return logger.info('Synchronous replication key updated by someone else.') @@ -614,11 +614,11 @@ def process_sync_replication(self) -> None: time.sleep(2) _, allow_promote = self.state_handler.sync_handler.current_state(self.cluster) if allow_promote and allow_promote != sync_common: - if not self.dcs.write_sync_state(self.state_handler.name, allow_promote, index=sync.index): + if 
not self.dcs.write_sync_state(self.state_handler.name, allow_promote, version=sync.version): return logger.info("Synchronous replication key updated by someone else") logger.info("Synchronous standby status assigned to %s", list(allow_promote)) else: - if not self.cluster.sync.is_empty and self.dcs.delete_sync_state(index=self.cluster.sync.index): + if not self.cluster.sync.is_empty and self.dcs.delete_sync_state(version=self.cluster.sync.version): logger.info("Disabled synchronous replication") self.state_handler.sync_handler.set_synchronous_standby_names(CaseInsensitiveSet()) @@ -727,7 +727,7 @@ def enforce_primary_role(self, message: str, promote_message: str) -> str: if self.is_synchronous_mode(): # Just set ourselves as the authoritative source of truth for now. We don't want to wait for standbys # to connect. We will try finding a synchronous standby in the next cycle. - if not self.dcs.write_sync_state(self.state_handler.name, None, index=self.cluster.sync.index): + if not self.dcs.write_sync_state(self.state_handler.name, None, version=self.cluster.sync.version): # Somebody else updated sync state, it may be due to us losing the lock. To be safe, postpone # promotion until next cycle. 
TODO: trigger immediate retry of run_cycle return 'Postponing promotion because synchronous replication state was updated by somebody else' @@ -800,9 +800,8 @@ def check_failsafe_topology(self) -> bool: data['slots'] = self.state_handler.slots() except Exception: logger.exception('Exception when called state_handler.slots()') - members = [RemoteMember(name, {'api_url': url}) - for name, url in failsafe.items() - if name != self.state_handler.name] + members = [RemoteMember.from_name_and_data(name, {'api_url': url}) + for name, url in failsafe.items() if name != self.state_handler.name] if not members: # A sinlge node cluster return True pool = ThreadPool(len(members)) @@ -907,7 +906,7 @@ def manual_failover_process_no_leader(self) -> Optional[bool]: if not self.cluster.get_member(failover.candidate, fallback_to_leader=False)\ and self.state_handler.is_leader(): logger.warning("manual failover: removing failover key because failover candidate is not running") - self.dcs.manual_failover('', '', index=failover.index) + self.dcs.manual_failover('', '', version=failover.version) return None return False @@ -998,7 +997,8 @@ def is_healthiest_node(self) -> bool: if failsafe_members and self.state_handler.name not in failsafe_members: return False # Race among not only existing cluster members, but also all known members from the failsafe config - all_known_members += [RemoteMember(name, {'api_url': url}) for name, url in failsafe_members.items()] + all_known_members += [RemoteMember.from_name_and_data(name, {'api_url': url}) + for name, url in failsafe_members.items()] all_known_members += self.cluster.members # When in sync mode, only last known primary and sync standby are allowed to promote automatically. 
@@ -1148,7 +1148,7 @@ def process_manual_failover_from_leader(self) -> Optional[str]: if (failover.scheduled_at and not self.should_run_scheduled_action("failover", failover.scheduled_at, lambda: - self.dcs.manual_failover('', '', index=failover.index))): + self.dcs.manual_failover('', '', version=failover.version))): return if not failover.leader or failover.leader == self.state_handler.name: @@ -1178,7 +1178,7 @@ def process_manual_failover_from_leader(self) -> Optional[str]: failover.leader, self.state_handler.name) logger.info('Cleaning up failover key') - self.dcs.manual_failover('', '', index=failover.index) + self.dcs.manual_failover('', '', version=failover.version) def process_unhealthy_cluster(self) -> str: """Cluster has no leader key""" @@ -1189,7 +1189,7 @@ def process_unhealthy_cluster(self) -> str: if failover: if self.is_paused() and failover.leader and failover.candidate: logger.info('Updating failover key after acquiring leader lock...') - self.dcs.manual_failover('', failover.candidate, failover.scheduled_at, failover.index) + self.dcs.manual_failover('', failover.candidate, failover.scheduled_at, failover.version) else: logger.info('Cleaning up failover key after acquiring leader lock...') self.dcs.manual_failover('', '') @@ -1839,11 +1839,11 @@ def _before_shutdown() -> None: def watch(self, timeout: float) -> bool: # watch on leader key changes if the postgres is running and leader is known and current node is not lock owner if self._async_executor.busy or not self.cluster or self.cluster.is_unlocked() or self.has_lock(False): - leader_index = None + leader_version = None else: - leader_index = self.cluster.leader.index if self.cluster.leader else None + leader_version = self.cluster.leader.version if self.cluster.leader else None - return self.dcs.watch(leader_index, timeout) + return self.dcs.watch(leader_version, timeout) def wakeup(self) -> None: """Call of this method will trigger the next run of HA loop if there is @@ -1868,4 +1868,4 @@ 
def get_remote_member(self, member: Union[Leader, Member, None] = None) -> Remot data['conn_kwargs'] = conn_kwargs name = member.name if member else 'remote_member:{}'.format(uuid.uuid1()) - return RemoteMember(name, data) + return RemoteMember.from_name_and_data(name, data) diff --git a/patroni/log.py b/patroni/log.py index 955594ffc..55b633867 100644 --- a/patroni/log.py +++ b/patroni/log.py @@ -13,7 +13,7 @@ from queue import Queue, Full from threading import Lock, Thread -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING _LOGGER = logging.getLogger(__name__) @@ -249,8 +249,9 @@ def reload_config(self, config: Dict[str, Any]) -> None: if not isinstance(self.log_handler, RotatingFileHandler): new_handler = RotatingFileHandler(os.path.join(config['dir'], __name__)) handler = new_handler or self.log_handler - assert isinstance(handler, RotatingFileHandler) - handler.maxBytes = int(config.get('file_size', 25000000)) + if TYPE_CHECKING: # pragma: no cover + assert isinstance(handler, RotatingFileHandler) + handler.maxBytes = int(config.get('file_size', 25000000)) # pyright: ignore [reportGeneralTypeIssues] handler.backupCount = int(config.get('file_num', 4)) else: if self.log_handler is None or isinstance(self.log_handler, RotatingFileHandler): @@ -306,7 +307,8 @@ def run(self) -> None: while True: self._close_old_handlers() - assert self.log_handler is not None + if TYPE_CHECKING: # pragma: no cover + assert self.log_handler is not None record = self._queue_handler.queue.get(True) # special message that indicates Patroni is shutting down diff --git a/patroni/postgresql/__init__.py b/patroni/postgresql/__init__.py index 7f7928263..305aa4ceb 100644 --- a/patroni/postgresql/__init__.py +++ b/patroni/postgresql/__init__.py @@ -192,7 +192,7 @@ def cluster_info_query(self) -> str: "FROM pg_catalog.pg_stat_get_wal_senders() w," " pg_catalog.pg_stat_get_activity(w.pid)" " WHERE w.state = 'streaming') 
r)").format(self.wal_name, self.lsn_name) - if (not self._global_config or self._global_config.is_synchronous_mode) + if (not self.global_config or self.global_config.is_synchronous_mode) and self.role in ('master', 'primary', 'promoted') else "'on', '', NULL") if self._major_version >= 90600: @@ -375,6 +375,10 @@ def set_enforce_hot_standby_feedback(self, value: bool) -> None: self.config.write_postgresql_conf() self.reload() + @property + def global_config(self) -> Optional['GlobalConfig']: + return self._global_config + def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: bool = False, global_config: Optional['GlobalConfig'] = None) -> None: """Reset monitoring query cache. @@ -388,7 +392,7 @@ def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: bo :param global_config: last known :class:`GlobalConfig` object """ self._cluster_info_state = {} - if cluster and cluster.config and cluster.config.modify_index: + if cluster and cluster.config and cluster.config.modify_version: self._has_permanent_logical_slots =\ cluster.has_permanent_logical_slots(self.name, nofailover, self.major_version) diff --git a/patroni/postgresql/config.py b/patroni/postgresql/config.py index 4d6bacfba..76a3fb7bc 100644 --- a/patroni/postgresql/config.py +++ b/patroni/postgresql/config.py @@ -860,9 +860,9 @@ def get_server_parameters(self, config: Dict[str, Any]) -> CaseInsensitiveDict: parameters = config['parameters'].copy() listen_addresses, port = split_host_port(config['listen'], 5432) parameters.update(cluster_name=self._postgresql.scope, listen_addresses=listen_addresses, port=str(port)) - if not self._postgresql._global_config or self._postgresql._global_config.is_synchronous_mode: + if not self._postgresql.global_config or self._postgresql.global_config.is_synchronous_mode: if self._synchronous_standby_names is None: - if self._postgresql._global_config and self._postgresql._global_config.is_synchronous_mode_strict\ + if 
self._postgresql.global_config and self._postgresql.global_config.is_synchronous_mode_strict\ and self._postgresql.role in ('master', 'primary', 'promoted'): parameters['synchronous_standby_names'] = '*' else: diff --git a/patroni/postgresql/sync.py b/patroni/postgresql/sync.py index 78e7b4272..c56bdbcd6 100644 --- a/patroni/postgresql/sync.py +++ b/patroni/postgresql/sync.py @@ -240,10 +240,10 @@ def current_state(self, cluster: Cluster) -> Tuple[CaseInsensitiveSet, CaseInsen if len(replica_list) > 1 else self._postgresql.last_operation() if TYPE_CHECKING: # pragma: no cover - assert self._postgresql._global_config is not None - sync_node_count = self._postgresql._global_config.synchronous_node_count\ + assert self._postgresql.global_config is not None + sync_node_count = self._postgresql.global_config.synchronous_node_count\ if self._postgresql.supports_multiple_sync else 1 - sync_node_maxlag = self._postgresql._global_config.maximum_lag_on_syncnode + sync_node_maxlag = self._postgresql.global_config.maximum_lag_on_syncnode candidates = CaseInsensitiveSet() sync_nodes = CaseInsensitiveSet() diff --git a/patroni/scripts/wale_restore.py b/patroni/scripts/wale_restore.py index c2d3545c5..7ef5d2c96 100755 --- a/patroni/scripts/wale_restore.py +++ b/patroni/scripts/wale_restore.py @@ -32,7 +32,7 @@ import time from enum import IntEnum -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple, TYPE_CHECKING from .. 
import psycopg @@ -365,7 +365,8 @@ def main() -> int: break time.sleep(RETRY_SLEEP_INTERVAL) - assert exit_code is not None + if TYPE_CHECKING: # pragma: no cover + assert exit_code is not None return exit_code diff --git a/patroni/validator.py b/patroni/validator.py index 569894d15..e867ce585 100644 --- a/patroni/validator.py +++ b/patroni/validator.py @@ -11,7 +11,7 @@ import socket import subprocess -from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType +from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType, TYPE_CHECKING from .utils import parse_int, split_host_port, data_directory_is_empty from .dcs import dcs_modules @@ -196,7 +196,8 @@ def get_major_version(bin_dir: OptionalType[str] = None) -> str: binary = os.path.join(bin_dir, 'postgres') version = subprocess.check_output([binary, '--version']).decode() version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', version) - assert version is not None + if TYPE_CHECKING: # pragma: no cover + assert version is not None return '.'.join([version.group(1), version.group(3)]) if int(version.group(1)) < 10 else version.group(1) diff --git a/patroni/watchdog/base.py b/patroni/watchdog/base.py index bebdabe25..7d1cdca67 100644 --- a/patroni/watchdog/base.py +++ b/patroni/watchdog/base.py @@ -34,7 +34,7 @@ def parse_mode(mode: Union[bool, str]) -> str: def synchronized(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(self: 'Watchdog', *args: Any, **kwargs: Any) -> Any: - with self._lock: + with self.lock: return func(self, *args, **kwargs) return wrapped @@ -90,7 +90,7 @@ class Watchdog(object): def __init__(self, config: Config) -> None: self.config = WatchdogConfig(config) self.active_config: WatchdogConfig = self.config - self._lock = RLock() + self.lock = RLock() self.active = False if self.config.mode == MODE_OFF: diff --git a/patroni/watchdog/linux.py b/patroni/watchdog/linux.py index 3c96a268b..6404f3038 100644 --- a/patroni/watchdog/linux.py +++ 
b/patroni/watchdog/linux.py @@ -1,3 +1,4 @@ +# pyright: reportConstantRedefinition=false import ctypes import os import platform diff --git a/pyrightconfig.json b/pyrightconfig.json index 20980afb3..2394c6de9 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -19,7 +19,7 @@ "reportMissingImports": true, "reportMissingTypeStubs": false, - "pythonVersion": "3.6", + "pythonVersion": "3.11", "pythonPlatform": "All", "typeCheckingMode": "strict" diff --git a/tests/test_ha.py b/tests/test_ha.py index c97345173..8495a1f0c 100644 --- a/tests/test_ha.py +++ b/tests/test_ha.py @@ -1201,7 +1201,7 @@ def test_sync_replication_become_primary(self): # When we just became primary nobody is sync self.assertEqual(self.ha.enforce_primary_role('msg', 'promote msg'), 'promote msg') mock_set_sync.assert_called_once_with(CaseInsensitiveSet()) - mock_write_sync.assert_called_once_with('leader', None, index=0) + mock_write_sync.assert_called_once_with('leader', None, version=0) mock_set_sync.reset_mock() @@ -1239,7 +1239,7 @@ def test_unhealthy_sync_mode(self): mock_acquire.assert_called_once() mock_follow.assert_not_called() mock_promote.assert_called_once() - mock_write_sync.assert_called_once_with('other', None, index=0) + mock_write_sync.assert_called_once_with('other', None, version=0) def test_disable_sync_when_restarting(self): self.ha.is_synchronous_mode = true diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 1a44b6716..f7c0c155e 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -339,7 +339,8 @@ def test_write_postgresql_and_sanitize_auto_conf(self): @patch.object(Postgresql, 'start', Mock()) def test_follow(self): self.p.call_nowait(CallbackAction.ON_START) - m = RemoteMember('1', {'restore_command': '2', 'primary_slot_name': 'foo', 'conn_kwargs': {'host': 'bar'}}) + m = RemoteMember.from_name_and_data('1', {'restore_command': '2', 'primary_slot_name': 'foo', + 'conn_kwargs': {'host': 'bar'}}) self.p.follow(m) with 
patch.object(Postgresql, 'ensure_major_version_is_known', Mock(return_value=False)): self.assertIsNone(self.p.follow(m))