diff --git a/patroni/dcs/etcd3.py b/patroni/dcs/etcd3.py index 5cbd813f6..ea7e52f24 100644 --- a/patroni/dcs/etcd3.py +++ b/patroni/dcs/etcd3.py @@ -124,6 +124,10 @@ class AuthFailed(InvalidArgument): error = "etcdserver: authentication failed, invalid user ID or password" +class AuthOldRevision(InvalidArgument): + error = "etcdserver: revision of auth store is old" + + class PermissionDenied(Etcd3ClientError): code = GRPCCode.PermissionDenied error = "etcdserver: permission denied" @@ -193,6 +197,12 @@ def build_range_request(key: str, range_end: Union[bytes, str, None] = None) -> return fields +class ReAuthenticateMode(IntEnum): + NOT_REQUIRED = 0 + REQUIRED = 1 + WITHOUT_WATCHER_RESTART = 2 + + def _handle_auth_errors(func: Callable[..., Any]) -> Any: def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any: return self.handle_auth_errors(func, *args, **kwargs) @@ -204,6 +214,7 @@ class Etcd3Client(AbstractEtcdClientWithFailover): ERROR_CLS = Etcd3Error def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None: + self._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED self._token = None self._cluster_version: Tuple[int, ...] = tuple() super(Etcd3Client, self).__init__({**config, 'version_prefix': '/v3beta'}, dns_resolver, cache_ttl) @@ -282,7 +293,7 @@ def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = fields['retry'] = retry return self.api_execute(self.version_prefix + method, self._MPOST, fields) - def authenticate(self) -> bool: + def authenticate(self, *, restart_watcher: bool = True, retry: Optional[Retry] = None) -> bool: if self._use_proxies and not self._cluster_version: kwargs = self._prepare_common_parameters(1) self._ensure_version_prefix(self._base_uri, **kwargs) @@ -291,7 +302,7 @@ def authenticate(self) -> bool: logger.info('Trying to authenticate on Etcd...') old_token, self._token = self._token, None try: - response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}) + response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}, retry) except AuthNotEnabled: logger.info('Etcd authentication is not enabled') self._token = None @@ -302,48 +313,65 @@ def authenticate(self) -> bool: self._token = response.get('token') return old_token != self._token - def handle_auth_errors(self: 'Etcd3Client', func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: - def retry(ex: Exception) -> Any: - if self.username and self.password: - self.authenticate() - return func(self, *args, **kwargs) - else: - logger.fatal('Username or password not set, authentication is not possible') - raise ex + def handle_auth_errors(self: 'Etcd3Client', func: Callable[..., Any], *args: Any, + retry: Optional[Retry] = None, **kwargs: Any) -> Any: + exc = None + while True: + if self._reauthenticate_reason: + if self.username and self.password: + self.authenticate( + restart_watcher=self._reauthenticate_reason != ReAuthenticateMode.WITHOUT_WATCHER_RESTART, + retry=retry) + self._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED + if retry: + retry.ensure_deadline(0) + else: + msg = 'Username or password not set, authentication is not possible' + logger.fatal(msg) + raise exc or Etcd3Exception(msg) - try: - return func(self, *args, **kwargs) - except (UserEmpty, PermissionDenied) as e: # no token provided - # PermissionDenied is raised on 3.0 and 3.1 - if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) - or self._cluster_version < (3, 2)): - raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' - 'supported on version lower than 3.3.0. Cluster version: ' - '{0}'.format('.'.join(map(str, self._cluster_version)))) - return retry(e) - except InvalidAuthToken as e: - logger.error('Invalid auth token: %s', self._token) - return retry(e) + try: + return func(self, *args, retry=retry, **kwargs) + except (UserEmpty, PermissionDenied) as e: # no token provided + # PermissionDenied is raised on 3.0 and 3.1 + if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) + or self._cluster_version < (3, 2)): + raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' + 'supported on version lower than 3.3.0. Cluster version: ' + '{0}'.format('.'.join(map(str, self._cluster_version)))) + exc = e + except InvalidAuthToken as e: + logger.error('Invalid auth token: %s', self._token) + exc = e + except AuthOldRevision as e: + logger.error('Auth token is for old revision of auth store') + exc = e + self._reauthenticate_reason = ReAuthenticateMode.WITHOUT_WATCHER_RESTART \ + if isinstance(exc, AuthOldRevision) else ReAuthenticateMode.REQUIRED + if not retry: + raise exc + retry.ensure_deadline(0.5, exc) @_handle_auth_errors def range(self, key: str, range_end: Union[bytes, str, None] = None, serializable: bool = True, - retry: Optional[Retry] = None) -> Dict[str, Any]: + *, retry: Optional[Retry] = None) -> Dict[str, Any]: params = build_range_request(key, range_end) params['serializable'] = serializable # For better performance. We can tolerate stale reads return self.call_rpc('/kv/range', params, retry) - def prefix(self, key: str, serializable: bool = True, retry: Optional[Retry] = None) -> Dict[str, Any]: - return self.range(key, prefix_range_end(key), serializable, retry) + def prefix(self, key: str, serializable: bool = True, *, retry: Optional[Retry] = None) -> Dict[str, Any]: + return self.range(key, prefix_range_end(key), serializable, retry=retry) @_handle_auth_errors - def lease_grant(self, ttl: int, retry: Optional[Retry] = None) -> str: + def lease_grant(self, ttl: int, *, retry: Optional[Retry] = None) -> str: return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID'] - def lease_keepalive(self, ID: str, retry: Optional[Retry] = None) -> Optional[str]: + def lease_keepalive(self, ID: str, *, retry: Optional[Retry] = None) -> Optional[str]: return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL') + @_handle_auth_errors def txn(self, compare: Dict[str, Any], success: Dict[str, Any], - failure: Optional[Dict[str, Any]] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: + failure: Optional[Dict[str, Any]] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]: fields = {'compare': [compare], 'success': [success]} if failure: fields['failure'] = [failure] @@ -352,7 +380,7 @@ def txn(self, compare: Dict[str, Any], success: Dict[str, Any], @_handle_auth_errors def put(self, key: str, value: str, lease: Optional[str] = None, create_revision: Optional[str] = None, - mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: + mod_revision: Optional[str] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]: fields = {'key': base64_encode(key), 'value': base64_encode(value)} if lease: fields['lease'] = lease @@ -367,14 +395,14 @@ def put(self, key: str, value: str, lease: Optional[str] = None, create_revision @_handle_auth_errors def deleterange(self, key: str, range_end: Union[bytes, str, None] = None, - mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: + mod_revision: Optional[str] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]: fields = build_range_request(key, range_end) if mod_revision is None: return self.call_rpc('/kv/deleterange', fields, retry) compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']} return self.txn(compare, {'request_delete_range': fields}, retry=retry) - def deleteprefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]: + def deleteprefix(self, key: str, *, retry: Optional[Retry] = None) -> Dict[str, Any]: return self.deleterange(key, prefix_range_end(key), retry=retry) def watchrange(self, key: str, range_end: Union[bytes, str, None] = None, @@ -574,9 +602,9 @@ def set_base_uri(self, value: str) -> None: super(PatroniEtcd3Client, self).set_base_uri(value) self._restart_watcher() - def authenticate(self) -> bool: - ret = super(PatroniEtcd3Client, self).authenticate() - if ret: + def authenticate(self, *, restart_watcher: bool = True, retry: Optional[Retry] = None) -> bool: + ret = super(PatroniEtcd3Client, self).authenticate(restart_watcher=restart_watcher, retry=retry) + if ret and restart_watcher: self._restart_watcher() return ret @@ -631,8 +659,8 @@ def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = return ret def txn(self, compare: Dict[str, Any], success: Dict[str, Any], - failure: Optional[Dict[str, Any]] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: - ret = super(PatroniEtcd3Client, self).txn(compare, success, failure, retry) + failure: Optional[Dict[str, Any]] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]: + ret = super(PatroniEtcd3Client, self).txn(compare, success, failure, retry=retry) # Here we abuse the fact that the `failure` is only set in the call from update_leader(). # In all other cases the txn() call failure may be an indicator of a stale cache, # and therefore we want to restart watcher. @@ -676,12 +704,12 @@ def _do_refresh_lease(self, force: bool = False, retry: Optional[Retry] = None) if not force and self._lease and self._last_lease_refresh + self._loop_wait > time.time(): return False - if self._lease and not self._client.lease_keepalive(self._lease, retry): + if self._lease and not self._client.lease_keepalive(self._lease, retry=retry): self._lease = None ret = not self._lease if ret: - self._lease = self._client.lease_grant(self._ttl, retry) + self._lease = self._client.lease_grant(self._ttl, retry=retry) self._last_lease_refresh = time.time() return ret diff --git a/tests/test_etcd3.py b/tests/test_etcd3.py index 9aed7eb19..10ab1ea50 100644 --- a/tests/test_etcd3.py +++ b/tests/test_etcd3.py @@ -6,8 +6,8 @@ from mock import Mock, PropertyMock, patch from patroni.dcs.etcd import DnsCachingResolver from patroni.dcs.etcd3 import PatroniEtcd3Client, Cluster, Etcd3, Etcd3Client, \ - Etcd3Error, Etcd3ClientError, RetryFailedError, InvalidAuthToken, Unavailable, \ - Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, base64_encode + Etcd3Error, Etcd3ClientError, ReAuthenticateMode, RetryFailedError, InvalidAuthToken, Unavailable, \ + Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, AuthOldRevision, base64_encode from threading import Thread from . import SleepException, MockResponse @@ -161,9 +161,16 @@ def test__handle_auth_errors(self, mock_urlopen): mock_urlopen.return_value.content = '{"code":16,"error":"etcdserver: invalid auth token"}' self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo') with patch.object(PatroniEtcd3Client, 'authenticate', Mock(return_value=True)): - self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo') + retry = self.etcd3._retry.copy() + with patch('time.time', Mock(side_effect=[0, 10, 20, 30, 40])): + self.assertRaises(InvalidAuthToken, retry, self.client.deleteprefix, 'foo', retry=retry) self.client.username = None - self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo') + self.client._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED + retry = self.etcd3._retry.copy() + self.assertRaises(InvalidAuthToken, retry, self.client.deleteprefix, 'foo', retry=retry) + mock_urlopen.return_value.content = '{"code":3,"error":"etcdserver: revision of auth store is old"}' + self.client._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED + self.assertRaises(AuthOldRevision, retry, self.client.deleteprefix, 'foo', retry=retry) def test__handle_server_response(self): response = MockResponse()