Skip to content

Commit

Permalink
Enable pyright job for every commit (patroni#2675)
Browse files Browse the repository at this point in the history
And fix remaining issues that the job doesn't fail.
  • Loading branch information
CyberDem0n authored May 15, 2023
1 parent fdcf8b1 commit 66a0e44
Show file tree
Hide file tree
Showing 23 changed files with 337 additions and 280 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,18 @@ jobs:
steps:
- run: bash <(curl -Ls https://coverage.codacy.com/get.sh) final
if: ${{ env.SECRETS_AVAILABLE == 'true' }}

pyright:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11

- name: Install dependencies
run: python -m pip install -r requirements.txt psycopg2-binary psycopg

- uses: jakebailey/pyright-action@v1
65 changes: 35 additions & 30 deletions patroni/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from threading import Thread
from urllib.parse import urlparse, parse_qs

from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union

from . import psycopg
from .__main__ import Patroni
Expand Down Expand Up @@ -86,7 +86,8 @@ def __init__(self, request: Any,
:param client_address: address of the client connection.
:param server: HTTP server that received the request.
"""
assert isinstance(server, RestApiServer)
if TYPE_CHECKING: # pragma: no cover
assert isinstance(server, RestApiServer)
super(RestApiHandler, self).__init__(request, client_address, server)
self.server: 'RestApiServer' = server
self.__start_time: float = 0.0
Expand All @@ -113,8 +114,8 @@ def _write_status_code_only(self, status_code: int) -> None:
self.wfile.write('{0} {1} {2}\r\n\r\n'.format(self.protocol_version, status_code, message).encode('utf-8'))
self.log_request(status_code)

def _write_response(self, status_code: int, body: str, content_type: str = 'text/html',
headers: Optional[Dict[str, str]] = None) -> None:
def write_response(self, status_code: int, body: str, content_type: str = 'text/html',
headers: Optional[Dict[str, str]] = None) -> None:
"""Write an HTTP response.
.. note::
Expand Down Expand Up @@ -143,12 +144,12 @@ def _write_response(self, status_code: int, body: str, content_type: str = 'text
def _write_json_response(self, status_code: int, response: Any) -> None:
"""Write an HTTP response with a JSON content type.
Call :func:`_write_response` with ``content_type`` as ``application/json``.
Call :func:`write_response` with ``content_type`` as ``application/json``.
:param status_code: response HTTP status code.
:param response: value to be dumped as a JSON string and to be used as the response body.
"""
self._write_response(status_code, json.dumps(response, default=str), content_type='application/json')
self.write_response(status_code, json.dumps(response, default=str), content_type='application/json')

def _write_status_response(self, status_code: int, response: Dict[str, Any]) -> None:
"""Write an HTTP response with Patroni/Postgres status in JSON format.
Expand Down Expand Up @@ -457,7 +458,7 @@ def do_GET_metrics(self) -> None:
* ``patroni_xlog_paused``: ``pg_is_wal_replay_paused()``;
* ``patroni_postgres_server_version``: Postgres version without periods, e.g. ``150002`` for Postgres ``15.2``;
* ``patroni_cluster_unlocked``: ``1`` if no one holds the leader lock, else ``0``;
* ``patroni_failsafe_mode_is_active``: ``1`` if ``failmode`` is currently active, else ``0``;
* ``patroni_failsafe_mode_is_active``: ``1`` if ``failsafe_mode`` is currently active, else ``0``;
* ``patroni_postgres_timeline``: PostgreSQL timeline based on current WAL file name;
* ``patroni_dcs_last_seen``: epoch timestamp when DCS was last contacted successfully;
* ``patroni_pending_restart``: ``1`` if this PostgreSQL node is pending a restart, else ``0``;
Expand Down Expand Up @@ -565,7 +566,7 @@ def do_GET_metrics(self) -> None:
metrics.append("# TYPE patroni_is_paused gauge")
metrics.append("patroni_is_paused{0} {1}".format(scope_label, int(postgres.get('pause', 0))))

self._write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain')
self.write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain')

def _read_json_content(self, body_is_optional: bool = False) -> Optional[Dict[Any, Any]]:
"""Read JSON from HTTP request body.
Expand Down Expand Up @@ -615,12 +616,12 @@ def do_PATCH_config(self) -> None:
request = self._read_json_content()
if request:
cluster = self.server.patroni.dcs.get_cluster(True)
if not (cluster.config and cluster.config.modify_index):
if not (cluster.config and cluster.config.modify_version):
return self.send_error(503)
data = cluster.config.data.copy()
if patch_config(data, request):
value = json.dumps(data, separators=(',', ':'))
if not self.server.patroni.dcs.set_config_value(value, cluster.config.index):
if not self.server.patroni.dcs.set_config_value(value, cluster.config.version):
return self.send_error(409)
self.server.patroni.ha.wakeup()
self._write_json_response(200, data)
Expand Down Expand Up @@ -651,7 +652,7 @@ def do_POST_reload(self) -> None:
Schedules a reload to Patroni and writes a response with HTTP status `202`.
"""
self.server.patroni.sighup_handler()
self._write_response(202, 'reload scheduled')
self.write_response(202, 'reload scheduled')

def do_GET_failsafe(self) -> None:
"""Handle a ``GET`` request to ``/failsafe`` path.
Expand Down Expand Up @@ -684,7 +685,7 @@ def do_POST_failsafe(self) -> None:
if request:
message = self.server.patroni.ha.update_failsafe(request) or 'Accepted'
code = 200 if message == 'Accepted' else 500
self._write_response(code, message)
self.write_response(code, message)
else:
self.send_error(502)

Expand All @@ -699,7 +700,7 @@ def do_POST_sigterm(self) -> None:
"""
if os.name == 'nt' and os.getenv('BEHAVE_DEBUG'):
self.server.patroni.api_sigterm()
self._write_response(202, 'shutdown scheduled')
self.write_response(202, 'shutdown scheduled')

@staticmethod
def parse_schedule(schedule: str,
Expand Down Expand Up @@ -779,7 +780,7 @@ def do_POST_restart(self) -> None:
logger.debug("received restart request: {0}".format(request))

if self.server.patroni.config.get_global_config(cluster).is_paused and 'schedule' in request:
self._write_response(status_code, "Can't schedule restart in the paused state")
self.write_response(status_code, "Can't schedule restart in the paused state")
return

for k in request:
Expand Down Expand Up @@ -827,8 +828,9 @@ def do_POST_restart(self) -> None:
status_code = 409
# pyright thinks ``data`` can be ``None`` because ``parse_schedule`` call may return ``None``. However, if
# that's the case, ``data`` will be overwritten when the ``for`` loop ends
assert isinstance(data, str)
self._write_response(status_code, data)
if TYPE_CHECKING: # pragma: no cover
assert isinstance(data, str)
self.write_response(status_code, data)

@check_access
def do_DELETE_restart(self) -> None:
Expand All @@ -846,7 +848,7 @@ def do_DELETE_restart(self) -> None:
else:
data = "no restarts are scheduled"
code = 404
self._write_response(code, data)
self.write_response(code, data)

@check_access
def do_DELETE_switchover(self) -> None:
Expand All @@ -861,15 +863,15 @@ def do_DELETE_switchover(self) -> None:
"""
failover = self.server.patroni.dcs.get_cluster().failover
if failover and failover.scheduled_at:
if not self.server.patroni.dcs.manual_failover('', '', index=failover.index):
if not self.server.patroni.dcs.manual_failover('', '', version=failover.version):
return self.send_error(409)
else:
data = "scheduled switchover deleted"
code = 200
else:
data = "no switchover is scheduled"
code = 404
self._write_response(code, data)
self.write_response(code, data)

@check_access
def do_POST_reinitialize(self) -> None:
Expand All @@ -895,7 +897,7 @@ def do_POST_reinitialize(self) -> None:
data = 'reinitialize started'
else:
status_code = 503
self._write_response(status_code, data)
self.write_response(status_code, data)

def poll_failover_result(self, leader: Optional[str], candidate: Optional[str], action: str) -> Tuple[int, str]:
"""Poll failover/switchover operation until it finishes or times out.
Expand Down Expand Up @@ -1033,9 +1035,10 @@ def do_POST_failover(self, action: str = 'failover') -> None:
status_code = 503
# pyright thinks ``status_code`` can be ``None`` because ``parse_schedule`` call may return ``None``. However,
# if that's the case, ``status_code`` will be overwritten somewhere between ``parse_schedule`` and
# ``_write_response`` calls.
assert isinstance(status_code, int)
self._write_response(status_code, data)
# ``write_response`` calls.
if TYPE_CHECKING: # pragma: no cover
assert isinstance(status_code, int)
self.write_response(status_code, data)

def do_POST_switchover(self) -> None:
"""Handle a ``POST`` request to ``/switchover`` path.
Expand All @@ -1062,7 +1065,7 @@ def do_POST_citus(self) -> None:
if patroni.postgresql.citus_handler.is_coordinator() and patroni.ha.is_leader():
cluster = patroni.dcs.get_cluster(True)
patroni.postgresql.citus_handler.handle_event(cluster, request)
self._write_response(200, 'OK')
self.write_response(200, 'OK')

def parse_request(self) -> bool:
"""Override :func:`parse_request` method to enrich basic functionality of :class:`BaseHTTPRequestHandler`.
Expand Down Expand Up @@ -1242,7 +1245,7 @@ def __init__(self, patroni: Patroni, config: Dict[str, Any]) -> None:
:param patroni: Patroni daemon process.
:param config: ``restapi`` section of Patroni configuration.
"""
self.connection_string = None
self.connection_string: str
self.__auth_key = None
self.__allowlist_include_members: Optional[bool] = None
self.__allowlist: Tuple[Union[IPv4Network, IPv6Network], ...] = ()
Expand Down Expand Up @@ -1300,7 +1303,8 @@ def check_basic_auth_key(self, key: str) -> bool:
:returns: ``True`` if *key* matches the password configured for the REST API.
"""
# pyright -- ``__auth_key`` was already checked through the caller method (:func:`check_auth_header`).
assert self.__auth_key is not None
if TYPE_CHECKING: # pragma: no cover
assert self.__auth_key is not None
return hmac.compare_digest(self.__auth_key, key.encode('utf-8'))

def check_auth_header(self, auth_header: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -1371,16 +1375,16 @@ def check_access(self, rh: RestApiHandler) -> Optional[bool]:
if self.__allowlist or self.__allowlist_include_members:
incoming_ip = ip_address(rh.client_address[0])
if not any(incoming_ip in net for net in self.__allowlist + tuple(self.__members_ips())):
return rh._write_response(403, 'Access is denied')
return rh.write_response(403, 'Access is denied')

if not hasattr(rh.request, 'getpeercert') or not rh.request.getpeercert(): # valid client cert isn't present
if self.__protocol == 'https' and self.__ssl_options.get('verify_client') in ('required', 'optional'):
return rh._write_response(403, 'client certificate required')
return rh.write_response(403, 'client certificate required')

reason = self.check_auth_header(rh.headers.get('Authorization'))
if reason:
headers = {'WWW-Authenticate': 'Basic realm="' + self.patroni.__class__.__name__ + '"'}
return rh._write_response(401, reason, headers=headers)
return rh.write_response(401, reason, headers=headers)
return True

@staticmethod
Expand Down Expand Up @@ -1603,7 +1607,8 @@ def reload_config(self, config: Dict[str, Any]) -> None:
self.__auth_key = base64.b64encode(config['auth'].encode('utf-8')) if 'auth' in config else None
# pyright -- ``__listen`` is initially created as ``None``, but right after that it is replaced with a string
# through :func:`__initialize`.
assert isinstance(self.__listen, str)
if TYPE_CHECKING: # pragma: no cover
assert isinstance(self.__listen, str)
self.connection_string = uri(self.__protocol, config.get('connect_address') or self.__listen, 'patroni')

def handle_error(self, request: Union[socket.socket, Tuple[bytes, socket.socket]],
Expand Down
15 changes: 8 additions & 7 deletions patroni/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections import defaultdict
from copy import deepcopy
from typing import Any, Callable, Collection, Dict, List, Optional, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Union, TYPE_CHECKING

from . import PATRONI_ENV_PREFIX
from .collections import CaseInsensitiveDict
Expand Down Expand Up @@ -151,7 +151,7 @@ def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict[str,
:returns: :class:`GlobalConfig` object
"""
# Try to protect from the case when DCS was wiped out
if cluster and cluster.config and cluster.config.modify_index:
if cluster and cluster.config and cluster.config.modify_version:
config = cluster.config.data
else:
config = default or {}
Expand Down Expand Up @@ -202,7 +202,7 @@ class Config(object):

def __init__(self, configfile: str,
validator: Optional[Callable[[Dict[str, Any]], List[str]]] = default_validator) -> None:
self._modify_index = -1
self._modify_version = -1
self._dynamic_configuration = {}

self.__environment_configuration = self._build_environment_configuration()
Expand Down Expand Up @@ -257,7 +257,8 @@ def _load_config_path(self, path: str) -> Dict[str, Any]:

def _load_config_file(self) -> Dict[str, Any]:
"""Loads config.yaml from filesystem and applies some values which were set via ENV"""
assert self._config_file is not None
if TYPE_CHECKING: # pragma: no cover
assert self._config_file is not None
config = self._load_config_path(self._config_file)
patch_config(config, self.__environment_configuration)
return config
Expand Down Expand Up @@ -296,9 +297,9 @@ def save_cache(self) -> None:
# configuration could be either ClusterConfig or dict
def set_dynamic_configuration(self, configuration: Union[ClusterConfig, Dict[str, Any]]) -> bool:
if isinstance(configuration, ClusterConfig):
if self._modify_index == configuration.modify_index:
return False # If the index didn't changed there is nothing to do
self._modify_index = configuration.modify_index
if self._modify_version == configuration.modify_version:
return False # If the version didn't changed there is nothing to do
self._modify_version = configuration.modify_version
configuration = configuration.data

if not deep_compare(self._dynamic_configuration, configuration):
Expand Down
20 changes: 11 additions & 9 deletions patroni/ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from psycopg2 import cursor

try:
from ydiff import markup_to_pager, PatchStream
from ydiff import markup_to_pager, PatchStream # pyright: ignore [reportMissingModuleSource]
except ImportError: # pragma: no cover
from cdiff import markup_to_pager, PatchStream
from cdiff import markup_to_pager, PatchStream # pyright: ignore [reportMissingModuleSource]

from .dcs import get_dcs as _get_dcs, AbstractDCS, Cluster, Member
from .exceptions import PatroniException
Expand Down Expand Up @@ -812,7 +812,8 @@ def _do_failover_or_switchover(obj: Dict[str, Any], action: str, cluster_name: s
r = None
try:
member = cluster.leader.member if cluster.leader else candidate and cluster.get_member(candidate, False)
assert isinstance(member, Member)
if TYPE_CHECKING: # pragma: no cover
assert isinstance(member, Member)
r = request_patroni(member, 'post', action, failover_value)

# probably old patroni, which doesn't support switchover yet
Expand Down Expand Up @@ -1052,15 +1053,15 @@ def flush(obj: Dict[str, Any], cluster_name: str, group: Optional[int],

logging.warning('Failing over to DCS')
click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name))
dcs.manual_failover('', '', index=failover.index)
dcs.manual_failover('', '', version=failover.version)


def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Cluster) -> None:
from patroni.config import get_global_config
config = get_global_config(old_cluster)

click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume'))
old = {m.name: m.index for m in old_cluster.members if m.api_url}
old = {m.name: m.version for m in old_cluster.members if m.api_url}
loop_wait = config.get('loop_wait') or dcs.loop_wait

cluster = None
Expand All @@ -1072,7 +1073,7 @@ def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Clu
if TYPE_CHECKING: # pragma: no cover
assert cluster is not None
remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused
and m.name in old and old[m.name] != m.index]
and m.name in old and old[m.name] != m.version]
if remaining:
return click.echo("{0} members didn't recognized pause state after {1} seconds"
.format(', '.join(remaining), loop_wait))
Expand Down Expand Up @@ -1169,7 +1170,7 @@ class opts:
(
os.path.basename(p)
for p in (os.environ.get('PAGER'), "less", "more")
if p is not None and shutil.which(p)
if p is not None and bool(shutil.which(p))
),
None,
)
Expand Down Expand Up @@ -1347,7 +1348,7 @@ def edit_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
return

if force or click.confirm('Apply these changes?'):
if not dcs.set_config_value(json.dumps(changed_data), cluster.config.index):
if not dcs.set_config_value(json.dumps(changed_data), cluster.config.version):
raise PatroniCtlException("Config modification aborted due to concurrent changes")
click.echo("Configuration changed")

Expand Down Expand Up @@ -1396,7 +1397,8 @@ def version(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member
@click.pass_obj
def history(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
history: List[List[Any]] = list(map(list, cluster.history and cluster.history.lines or []))
cluster_history = cluster.history.lines if cluster.history else []
history: List[List[Any]] = list(map(list, cluster_history))
table_header_row = ['TL', 'LSN', 'Reason', 'Timestamp', 'New Leader']
for line in history:
if len(line) < len(table_header_row):
Expand Down
Loading

0 comments on commit 66a0e44

Please sign in to comment.