Skip to content

Commit

Permalink
Handle AuthOldRevision error (patroni#2913)
Browse files Browse the repository at this point in the history
The error is raised if Etcd is configured to use JWT auth tokens and when the user database in Etcd is updated, because the update invalidates all tokens.

If retries are requested - try to get a new new token and repeat the request. Repeat it in a loop until request is successfully executed or until `retry_timeout` is exhausted. This is the only way of solving a race condition, because between authentication and executing the request yet another modification of the user database in Etcd might happen.

In case if the request doesn't have to be immediately retried - set a flag that the next API request should perform the authentication first and let Patroni to naturally repeat the request on the next heartbeat loop.

Co-authored-by: Kenny Do <[email protected]>
Ref: patroni#2911
  • Loading branch information
CyberDem0n and kennydo authored Oct 23, 2023
1 parent 6d98944 commit d471f11
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 43 deletions.
106 changes: 67 additions & 39 deletions patroni/dcs/etcd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class AuthFailed(InvalidArgument):
error = "etcdserver: authentication failed, invalid user ID or password"


class AuthOldRevision(InvalidArgument):
error = "etcdserver: revision of auth store is old"


class PermissionDenied(Etcd3ClientError):
code = GRPCCode.PermissionDenied
error = "etcdserver: permission denied"
Expand Down Expand Up @@ -193,6 +197,12 @@ def build_range_request(key: str, range_end: Union[bytes, str, None] = None) ->
return fields


class ReAuthenticateMode(IntEnum):
NOT_REQUIRED = 0
REQUIRED = 1
WITHOUT_WATCHER_RESTART = 2


def _handle_auth_errors(func: Callable[..., Any]) -> Any:
def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any:
return self.handle_auth_errors(func, *args, **kwargs)
Expand All @@ -204,6 +214,7 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
ERROR_CLS = Etcd3Error

def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None:
self._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED
self._token = None
self._cluster_version: Tuple[int, ...] = tuple()
super(Etcd3Client, self).__init__({**config, 'version_prefix': '/v3beta'}, dns_resolver, cache_ttl)
Expand Down Expand Up @@ -282,7 +293,7 @@ def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] =
fields['retry'] = retry
return self.api_execute(self.version_prefix + method, self._MPOST, fields)

def authenticate(self) -> bool:
def authenticate(self, *, restart_watcher: bool = True, retry: Optional[Retry] = None) -> bool:
if self._use_proxies and not self._cluster_version:
kwargs = self._prepare_common_parameters(1)
self._ensure_version_prefix(self._base_uri, **kwargs)
Expand All @@ -291,7 +302,7 @@ def authenticate(self) -> bool:
logger.info('Trying to authenticate on Etcd...')
old_token, self._token = self._token, None
try:
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password})
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}, retry)
except AuthNotEnabled:
logger.info('Etcd authentication is not enabled')
self._token = None
Expand All @@ -302,48 +313,65 @@ def authenticate(self) -> bool:
self._token = response.get('token')
return old_token != self._token

def handle_auth_errors(self: 'Etcd3Client', func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
def retry(ex: Exception) -> Any:
if self.username and self.password:
self.authenticate()
return func(self, *args, **kwargs)
else:
logger.fatal('Username or password not set, authentication is not possible')
raise ex
def handle_auth_errors(self: 'Etcd3Client', func: Callable[..., Any], *args: Any,
retry: Optional[Retry] = None, **kwargs: Any) -> Any:
exc = None
while True:
if self._reauthenticate_reason:
if self.username and self.password:
self.authenticate(
restart_watcher=self._reauthenticate_reason != ReAuthenticateMode.WITHOUT_WATCHER_RESTART,
retry=retry)
self._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED
if retry:
retry.ensure_deadline(0)
else:
msg = 'Username or password not set, authentication is not possible'
logger.fatal(msg)
raise exc or Etcd3Exception(msg)

try:
return func(self, *args, **kwargs)
except (UserEmpty, PermissionDenied) as e: # no token provided
# PermissionDenied is raised on 3.0 and 3.1
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
or self._cluster_version < (3, 2)):
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
'supported on version lower than 3.3.0. Cluster version: '
'{0}'.format('.'.join(map(str, self._cluster_version))))
return retry(e)
except InvalidAuthToken as e:
logger.error('Invalid auth token: %s', self._token)
return retry(e)
try:
return func(self, *args, retry=retry, **kwargs)
except (UserEmpty, PermissionDenied) as e: # no token provided
# PermissionDenied is raised on 3.0 and 3.1
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
or self._cluster_version < (3, 2)):
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
'supported on version lower than 3.3.0. Cluster version: '
'{0}'.format('.'.join(map(str, self._cluster_version))))
exc = e
except InvalidAuthToken as e:
logger.error('Invalid auth token: %s', self._token)
exc = e
except AuthOldRevision as e:
logger.error('Auth token is for old revision of auth store')
exc = e
self._reauthenticate_reason = ReAuthenticateMode.WITHOUT_WATCHER_RESTART \
if isinstance(exc, AuthOldRevision) else ReAuthenticateMode.REQUIRED
if not retry:
raise exc
retry.ensure_deadline(0.5, exc)

@_handle_auth_errors
def range(self, key: str, range_end: Union[bytes, str, None] = None, serializable: bool = True,
retry: Optional[Retry] = None) -> Dict[str, Any]:
*, retry: Optional[Retry] = None) -> Dict[str, Any]:
params = build_range_request(key, range_end)
params['serializable'] = serializable # For better performance. We can tolerate stale reads
return self.call_rpc('/kv/range', params, retry)

def prefix(self, key: str, serializable: bool = True, retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.range(key, prefix_range_end(key), serializable, retry)
def prefix(self, key: str, serializable: bool = True, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.range(key, prefix_range_end(key), serializable, retry=retry)

@_handle_auth_errors
def lease_grant(self, ttl: int, retry: Optional[Retry] = None) -> str:
def lease_grant(self, ttl: int, *, retry: Optional[Retry] = None) -> str:
return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID']

def lease_keepalive(self, ID: str, retry: Optional[Retry] = None) -> Optional[str]:
def lease_keepalive(self, ID: str, *, retry: Optional[Retry] = None) -> Optional[str]:
return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL')

@_handle_auth_errors
def txn(self, compare: Dict[str, Any], success: Dict[str, Any],
failure: Optional[Dict[str, Any]] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
failure: Optional[Dict[str, Any]] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
fields = {'compare': [compare], 'success': [success]}
if failure:
fields['failure'] = [failure]
Expand All @@ -352,7 +380,7 @@ def txn(self, compare: Dict[str, Any], success: Dict[str, Any],

@_handle_auth_errors
def put(self, key: str, value: str, lease: Optional[str] = None, create_revision: Optional[str] = None,
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
mod_revision: Optional[str] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
fields = {'key': base64_encode(key), 'value': base64_encode(value)}
if lease:
fields['lease'] = lease
Expand All @@ -367,14 +395,14 @@ def put(self, key: str, value: str, lease: Optional[str] = None, create_revision

@_handle_auth_errors
def deleterange(self, key: str, range_end: Union[bytes, str, None] = None,
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
mod_revision: Optional[str] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
fields = build_range_request(key, range_end)
if mod_revision is None:
return self.call_rpc('/kv/deleterange', fields, retry)
compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']}
return self.txn(compare, {'request_delete_range': fields}, retry=retry)

def deleteprefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]:
def deleteprefix(self, key: str, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.deleterange(key, prefix_range_end(key), retry=retry)

def watchrange(self, key: str, range_end: Union[bytes, str, None] = None,
Expand Down Expand Up @@ -574,9 +602,9 @@ def set_base_uri(self, value: str) -> None:
super(PatroniEtcd3Client, self).set_base_uri(value)
self._restart_watcher()

def authenticate(self) -> bool:
ret = super(PatroniEtcd3Client, self).authenticate()
if ret:
def authenticate(self, *, restart_watcher: bool = True, retry: Optional[Retry] = None) -> bool:
ret = super(PatroniEtcd3Client, self).authenticate(restart_watcher=restart_watcher, retry=retry)
if ret and restart_watcher:
self._restart_watcher()
return ret

Expand Down Expand Up @@ -631,8 +659,8 @@ def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] =
return ret

def txn(self, compare: Dict[str, Any], success: Dict[str, Any],
failure: Optional[Dict[str, Any]] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
ret = super(PatroniEtcd3Client, self).txn(compare, success, failure, retry)
failure: Optional[Dict[str, Any]] = None, *, retry: Optional[Retry] = None) -> Dict[str, Any]:
ret = super(PatroniEtcd3Client, self).txn(compare, success, failure, retry=retry)
# Here we abuse the fact that the `failure` is only set in the call from update_leader().
# In all other cases the txn() call failure may be an indicator of a stale cache,
# and therefore we want to restart watcher.
Expand Down Expand Up @@ -676,12 +704,12 @@ def _do_refresh_lease(self, force: bool = False, retry: Optional[Retry] = None)
if not force and self._lease and self._last_lease_refresh + self._loop_wait > time.time():
return False

if self._lease and not self._client.lease_keepalive(self._lease, retry):
if self._lease and not self._client.lease_keepalive(self._lease, retry=retry):
self._lease = None

ret = not self._lease
if ret:
self._lease = self._client.lease_grant(self._ttl, retry)
self._lease = self._client.lease_grant(self._ttl, retry=retry)

self._last_lease_refresh = time.time()
return ret
Expand Down
15 changes: 11 additions & 4 deletions tests/test_etcd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from mock import Mock, PropertyMock, patch
from patroni.dcs.etcd import DnsCachingResolver
from patroni.dcs.etcd3 import PatroniEtcd3Client, Cluster, Etcd3, Etcd3Client, \
Etcd3Error, Etcd3ClientError, RetryFailedError, InvalidAuthToken, Unavailable, \
Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, base64_encode
Etcd3Error, Etcd3ClientError, ReAuthenticateMode, RetryFailedError, InvalidAuthToken, Unavailable, \
Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, AuthOldRevision, base64_encode
from threading import Thread

from . import SleepException, MockResponse
Expand Down Expand Up @@ -161,9 +161,16 @@ def test__handle_auth_errors(self, mock_urlopen):
mock_urlopen.return_value.content = '{"code":16,"error":"etcdserver: invalid auth token"}'
self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo')
with patch.object(PatroniEtcd3Client, 'authenticate', Mock(return_value=True)):
self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo')
retry = self.etcd3._retry.copy()
with patch('time.time', Mock(side_effect=[0, 10, 20, 30, 40])):
self.assertRaises(InvalidAuthToken, retry, self.client.deleteprefix, 'foo', retry=retry)
self.client.username = None
self.assertRaises(InvalidAuthToken, self.client.deleteprefix, 'foo')
self.client._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED
retry = self.etcd3._retry.copy()
self.assertRaises(InvalidAuthToken, retry, self.client.deleteprefix, 'foo', retry=retry)
mock_urlopen.return_value.content = '{"code":3,"error":"etcdserver: revision of auth store is old"}'
self.client._reauthenticate_reason = ReAuthenticateMode.NOT_REQUIRED
self.assertRaises(AuthOldRevision, retry, self.client.deleteprefix, 'foo', retry=retry)

def test__handle_server_response(self):
response = MockResponse()
Expand Down

0 comments on commit d471f11

Please sign in to comment.