From 63e97c16cb4207f46d4b00f29ec608d4d051366e Mon Sep 17 00:00:00 2001
From: Evan Parker
Date: Thu, 16 Nov 2017 10:53:47 -0800
Subject: [PATCH] Update client to accommodate oauth2client>=4.0 (#184)

This change has been needed for a while now. The main blocker seems to
be the use of locked_file for caching GCE credentials. I've added a
simple multiprocess lockable file cache that uses a similar approach to
that used in oauth2client's multiprocess file storage. Submission of
this should close issue #162.
---
 .travis.yml                              |   5 +-
 apitools/base/py/credentials_lib.py      | 213 +++++++++++++++--------
 apitools/base/py/credentials_lib_test.py | 100 ++++++++---
 apitools/base/py/http_wrapper.py         |  10 +-
 setup.py                                 |   3 +-
 tox.ini                                  |  49 +++++-
 6 files changed, 280 insertions(+), 100 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 1e80e143..07b83aae 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,8 +3,11 @@ sudo: false
 env:
 - TOX_ENV=py27
 - TOX_ENV=py27oldoauth2client
+- TOX_ENV=py27newoauth2client
 - TOX_ENV=py34
-- TOX_ENV=py35
+- TOX_ENV=py35oauth2client15
+- TOX_ENV=py35oauth2client30
+- TOX_ENV=py35oauth2client41
 - TOX_ENV=lint
 install:
 - pip install tox
diff --git a/apitools/base/py/credentials_lib.py b/apitools/base/py/credentials_lib.py
index 913c1440..ddf46b5a 100644
--- a/apitools/base/py/credentials_lib.py
+++ b/apitools/base/py/credentials_lib.py
@@ -17,17 +17,20 @@
 """Common credentials classes and constructors."""
 from __future__ import print_function
 
+import contextlib
 import datetime
 import json
 import os
 import threading
 import warnings
 
+import fasteners
 import httplib2
 import oauth2client
 import oauth2client.client
 from oauth2client import service_account
 from oauth2client import tools  # for gflags declarations
+import six
 from six.moves import http_client
 from six.moves import urllib
 
@@ -45,14 +48,14 @@
     from oauth2client import gce
 
 try:
-    from oauth2client.contrib import locked_file
+    from oauth2client.contrib import multiprocess_file_storage
+    _NEW_FILESTORE = True
 except ImportError:
-    from oauth2client import locked_file
-
-try:
-    from oauth2client.contrib import multistore_file
-except ImportError:
-    from oauth2client import multistore_file
+    _NEW_FILESTORE = False
+    try:
+        from oauth2client.contrib import multistore_file
+    except ImportError:
+        from oauth2client import multistore_file
 
 try:
     import gflags
@@ -193,19 +196,6 @@ def ServiceAccountCredentialsFromP12File(
         user_agent=user_agent)
 
 
-def _EnsureFileExists(filename):
-    """Touches a file; returns False on error, True on success."""
-    if not os.path.exists(filename):
-        old_umask = os.umask(0o177)
-        try:
-            open(filename, 'a+b').close()
-        except OSError:
-            return False
-        finally:
-            os.umask(old_umask)
-    return True
-
-
 def _GceMetadataRequest(relative_url, use_metadata_ip=False):
     """Request the given url from the GCE metadata service."""
     if use_metadata_ip:
@@ -288,29 +278,20 @@ def _CheckCacheFileForMatch(self, cache_filename, scopes):
             'scopes': sorted(list(scopes)) if scopes else None,
             'svc_acct_name': self.__service_account_name,
         }
-        with cache_file_lock:
-            if _EnsureFileExists(cache_filename):
-                cache_file = locked_file.LockedFile(
-                    cache_filename, 'r+b', 'rb')
-                try:
-                    cache_file.open_and_lock()
-                    cached_creds_str = cache_file.file_handle().read()
-                    if cached_creds_str:
-                        # Cached credentials metadata dict.
-                        cached_creds = json.loads(cached_creds_str)
-                        if (creds['svc_acct_name'] ==
-                                cached_creds['svc_acct_name']):
-                            if (creds['scopes'] in
-                                    (None, cached_creds['scopes'])):
-                                scopes = cached_creds['scopes']
-                except KeyboardInterrupt:
-                    raise
-                except:  # pylint: disable=bare-except
-                    # Treat exceptions as a cache miss.
-                    pass
-                finally:
-                    cache_file.unlock_and_close()
-        return scopes
+        cache_file = _MultiProcessCacheFile(cache_filename)
+        try:
+            cached_creds_str = cache_file.LockedRead()
+            if not cached_creds_str:
+                return None
+            cached_creds = json.loads(cached_creds_str)
+            if creds['svc_acct_name'] == cached_creds['svc_acct_name']:
+                if creds['scopes'] in (None, cached_creds['scopes']):
+                    return cached_creds['scopes']
+        except KeyboardInterrupt:
+            raise
+        except:  # pylint: disable=bare-except
+            # Treat exceptions as a cache miss.
+            pass
 
     def _WriteCacheFile(self, cache_filename, scopes):
         """Writes the credential metadata to the cache file.
@@ -322,28 +303,18 @@ def _WriteCacheFile(self, cache_filename, scopes):
             cache_filename: Cache filename to check.
             scopes: Scopes for the desired credentials.
         """
-        with cache_file_lock:
-            if _EnsureFileExists(cache_filename):
-                cache_file = locked_file.LockedFile(
-                    cache_filename, 'r+b', 'rb')
-                try:
-                    cache_file.open_and_lock()
-                    if cache_file.is_locked():
-                        creds = {  # Credentials metadata dict.
-                            'scopes': sorted(list(scopes)),
-                            'svc_acct_name': self.__service_account_name}
-                        cache_file.file_handle().write(
-                            json.dumps(creds, encoding='ascii'))
-                    # If it's not locked, the locking process will
-                    # write the same data to the file, so just
-                    # continue.
-                except KeyboardInterrupt:
-                    raise
-                except:  # pylint: disable=bare-except
-                    # Treat exceptions as a cache miss.
-                    pass
-                finally:
-                    cache_file.unlock_and_close()
+        # Credentials metadata dict.
+        creds = {'scopes': sorted(list(scopes)),
+                 'svc_acct_name': self.__service_account_name}
+        creds_str = json.dumps(creds)
+        cache_file = _MultiProcessCacheFile(cache_filename)
+        try:
+            cache_file.LockedWrite(creds_str)
+        except KeyboardInterrupt:
+            raise
+        except:  # pylint: disable=bare-except
+            # Treat exceptions as a cache miss.
+            pass
 
     def _ScopesFromMetadataServer(self, scopes):
         """Returns instance scopes based on GCE metadata server."""
@@ -537,11 +508,18 @@ def _GetRunFlowFlags(args=None):
 # TODO(craigcitro): Switch this from taking a path to taking a stream.
 def CredentialsFromFile(path, client_info, oauth2client_args=None):
     """Read credentials from a file."""
-    credential_store = multistore_file.get_credential_storage(
-        path,
-        client_info['client_id'],
-        client_info['user_agent'],
-        client_info['scope'])
+    user_agent = client_info['user_agent']
+    scope_key = client_info['scope']
+    if not isinstance(scope_key, six.string_types):
+        scope_key = ':'.join(scope_key)
+    storage_key = client_info['client_id'] + user_agent + scope_key
+
+    if _NEW_FILESTORE:
+        credential_store = multiprocess_file_storage.MultiprocessFileStorage(
+            path, storage_key)
+    else:
+        credential_store = multistore_file.get_credential_storage_custom_string_key(  # noqa
+            path, storage_key)
     if hasattr(FLAGS, 'auth_local_webserver'):
         FLAGS.auth_local_webserver = False
     credentials = credential_store.get()
@@ -568,6 +546,101 @@ def CredentialsFromFile(path, client_info, oauth2client_args=None):
     return credentials
 
 
+class _MultiProcessCacheFile(object):
+    """Simple multithreading and multiprocessing safe cache file.
+
+    Notes on behavior:
+     * the fasteners.InterProcessLock object cannot reliably prevent threads
+       from double-acquiring a lock. A threading lock is used in addition to
+       the InterProcessLock. The threading lock is always acquired first and
+       released last.
+     * The interprocess lock will not deadlock. If a process can not acquire
+       the interprocess lock within `_lock_timeout` the call will return as
+       a cache miss or unsuccessful cache write.
+    """
+
+    _lock_timeout = 1
+    _encoding = 'utf-8'
+    _thread_lock = threading.Lock()
+
+    def __init__(self, filename):
+        self._file = None
+        self._filename = filename
+        self._process_lock = fasteners.InterProcessLock(
+            '{0}.lock'.format(filename))
+
+    @contextlib.contextmanager
+    def _ProcessLockAcquired(self):
+        """Context manager for process locks with timeout."""
+        try:
+            is_locked = self._process_lock.acquire(timeout=self._lock_timeout)
+            yield is_locked
+        finally:
+            if is_locked:
+                self._process_lock.release()
+
+    def LockedRead(self):
+        """Acquire an interprocess lock and dump cache contents.
+
+        This method safely acquires the locks then reads a string
+        from the cache file. If the file does not exist and cannot
+        be created, it will return None. If the locks cannot be
+        acquired, this will also return None.
+
+        Returns:
+          cache data - string if present, None on failure.
+        """
+        file_contents = None
+        with self._thread_lock:
+            if not self._EnsureFileExists():
+                return None
+            with self._ProcessLockAcquired() as acquired_plock:
+                if not acquired_plock:
+                    return None
+                with open(self._filename, 'rb') as f:
+                    file_contents = f.read().decode(encoding=self._encoding)
+        return file_contents
+
+    def LockedWrite(self, cache_data):
+        """Acquire an interprocess lock and write a string.
+
+        This method safely acquires the locks then writes a string
+        to the cache file. If the string is written successfully
+        the function will return True, if the write fails for any
+        reason it will return False.
+
+        Args:
+          cache_data: string or bytes to write.
+
+        Returns:
+          bool: success
+        """
+        if isinstance(cache_data, six.text_type):
+            cache_data = cache_data.encode(encoding=self._encoding)
+
+        with self._thread_lock:
+            if not self._EnsureFileExists():
+                return False
+            with self._ProcessLockAcquired() as acquired_plock:
+                if not acquired_plock:
+                    return False
+                with open(self._filename, 'wb') as f:
+                    f.write(cache_data)
+            return True
+
+    def _EnsureFileExists(self):
+        """Touches a file; returns False on error, True on success."""
+        if not os.path.exists(self._filename):
+            old_umask = os.umask(0o177)
+            try:
+                open(self._filename, 'a+b').close()
+            except OSError:
+                return False
+            finally:
+                os.umask(old_umask)
+        return True
+
+
 # TODO(craigcitro): Push this into oauth2client.
 def GetUserinfo(credentials, http=None):  # pylint: disable=invalid-name
     """Get the userinfo associated with the given credentials.
diff --git a/apitools/base/py/credentials_lib_test.py b/apitools/base/py/credentials_lib_test.py
index 1bf5aa7d..d628d71b 100644
--- a/apitools/base/py/credentials_lib_test.py
+++ b/apitools/base/py/credentials_lib_test.py
@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os.path
+import shutil
+import tempfile
+
 import mock
 import six
 import unittest2
@@ -21,43 +25,91 @@
 from apitools.base.py import util
 
 
+class MetadataMock(object):
+
+    def __init__(self, scopes=None, service_account_name=None):
+        self._scopes = scopes or ['scope1']
+        self._sa = service_account_name or 'default'
+
+    def __call__(self, request_url):
+        if request_url.endswith('scopes'):
+            return six.StringIO(''.join(self._scopes))
+        elif request_url.endswith('service-accounts'):
+            return six.StringIO(self._sa)
+        elif request_url.endswith(
+                '/service-accounts/%s/token' % self._sa):
+            return six.StringIO('{"access_token": "token"}')
+        self.fail('Unexpected HTTP request to %s' % request_url)
+
+
 class CredentialsLibTest(unittest2.TestCase):
 
-    def _GetServiceCreds(self, service_account_name=None, scopes=None):
+    def _RunGceAssertionCredentials(
+            self, service_account_name=None, scopes=None, cache_filename=None):
         kwargs = {}
         if service_account_name is not None:
             kwargs['service_account_name'] = service_account_name
+        if cache_filename is not None:
+            kwargs['cache_filename'] = cache_filename
         service_account_name = service_account_name or 'default'
+        credentials = credentials_lib.GceAssertionCredentials(
+            scopes, **kwargs)
+        self.assertIsNone(credentials._refresh(None))
+        return credentials
 
-        def MockMetadataCalls(request_url):
-            default_scopes = scopes or ['scope1']
-            if request_url.endswith('scopes'):
-                return six.StringIO(''.join(default_scopes))
-            elif request_url.endswith('service-accounts'):
-                return six.StringIO(service_account_name)
-            elif request_url.endswith(
-                    '/service-accounts/%s/token' % service_account_name):
-                return six.StringIO('{"access_token": "token"}')
-            self.fail('Unexpected HTTP request to %s' % request_url)
-
-        with mock.patch.object(credentials_lib, '_GceMetadataRequest',
-                               side_effect=MockMetadataCalls,
-                               autospec=True) as opener_mock:
-            with mock.patch.object(util, 'DetectGce',
-                                   autospec=True) as mock_detect:
-                mock_detect.return_value = True
-                credentials = credentials_lib.GceAssertionCredentials(
-                    scopes, **kwargs)
-                self.assertIsNone(credentials._refresh(None))
+    def _GetServiceCreds(self, service_account_name=None, scopes=None):
+        metadatamock = MetadataMock(scopes, service_account_name)
+        with mock.patch.object(util, 'DetectGce', autospec=True) as gce_detect:
+            gce_detect.return_value = True
+            with mock.patch.object(credentials_lib,
+                                   '_GceMetadataRequest',
+                                   side_effect=metadatamock,
+                                   autospec=True) as opener_mock:
+                credentials = self._RunGceAssertionCredentials(
+                    service_account_name=service_account_name,
+                    scopes=scopes)
             self.assertEqual(3, opener_mock.call_count)
         return credentials
 
     def testGceServiceAccounts(self):
         scopes = ['scope1']
-        self._GetServiceCreds()
-        self._GetServiceCreds(scopes=scopes)
-        self._GetServiceCreds(service_account_name='my_service_account',
+        self._GetServiceCreds(service_account_name=None,
+                              scopes=None)
+        self._GetServiceCreds(service_account_name=None,
                               scopes=scopes)
+        self._GetServiceCreds(
+            service_account_name='my_service_account',
+            scopes=scopes)
+
+    @mock.patch.object(util, 'DetectGce', autospec=True)
+    def testGceServiceAccountsCached(self, mock_detect):
+        mock_detect.return_value = True
+        tempd = tempfile.mkdtemp()
+        tempname = os.path.join(tempd, 'creds')
+        scopes = ['scope1']
+        service_account_name = 'some_service_account_name'
+        metadatamock = MetadataMock(scopes, service_account_name)
+        with mock.patch.object(credentials_lib,
+                               '_GceMetadataRequest',
+                               side_effect=metadatamock,
+                               autospec=True) as opener_mock:
+            try:
+                creds1 = self._RunGceAssertionCredentials(
+                    service_account_name=service_account_name,
+                    cache_filename=tempname,
+                    scopes=scopes)
+                pre_cache_call_count = opener_mock.call_count
+                creds2 = self._RunGceAssertionCredentials(
+                    service_account_name=service_account_name,
+                    cache_filename=tempname,
+                    scopes=None)
+            finally:
+                shutil.rmtree(tempd)
+        self.assertEqual(creds1.client_id, creds2.client_id)
+        self.assertEqual(pre_cache_call_count, 3)
+        # Caching obviates the need for extra metadata server requests.
+        # Only one metadata request is made if the cache is hit.
+        self.assertEqual(opener_mock.call_count, 4)
 
     def testGetServiceAccount(self):
         # We'd also like to test the metadata calls, which requires
diff --git a/apitools/base/py/http_wrapper.py b/apitools/base/py/http_wrapper.py
index 7baf09f2..c5fe225a 100644
--- a/apitools/base/py/http_wrapper.py
+++ b/apitools/base/py/http_wrapper.py
@@ -27,7 +27,6 @@
 import time
 
 import httplib2
-import oauth2client
 import six
 from six.moves import http_client
 from six.moves.urllib import parse
@@ -35,6 +34,12 @@
 from apitools.base.py import exceptions
 from apitools.base.py import util
 
+# pylint: disable=ungrouped-imports
+try:
+    from oauth2client.client import HttpAccessTokenRefreshError as TokenRefreshError  # noqa
+except ImportError:
+    from oauth2client.client import AccessTokenRefreshError as TokenRefreshError  # noqa
+
 __all__ = [
     'CheckResponse',
     'GetHttp',
@@ -279,8 +284,7 @@ def HandleExceptionsAndRebuildHttpConnections(retry_args):
         # oauth2client, need to handle it here.
         logging.debug('Response content was invalid (%s), retrying',
                       retry_args.exc)
-    elif (isinstance(retry_args.exc,
-                     oauth2client.client.HttpAccessTokenRefreshError) and
+    elif (isinstance(retry_args.exc, TokenRefreshError) and
          (retry_args.exc.status == TOO_MANY_REQUESTS or
           retry_args.exc.status >= 500)):
        logging.debug(
diff --git a/setup.py b/setup.py
index 2b4cfef4..ed90948f 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,8 @@
 # Python version and OS.
 REQUIRED_PACKAGES = [
     'httplib2>=0.8',
-    'oauth2client>=1.5.2,<4.0.0dev',
+    'fasteners>=0.14',
+    'oauth2client>=1.4.12',
     'six>=1.9.0',
 ]
 
diff --git a/tox.ini b/tox.ini
index e2d1f233..7a822355 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,16 @@
 [tox]
-envlist = py26,py27,pypy,py34,py35,lint,cover,py27oldoauth2client
+envlist =
+    py26
+    py27
+    py27oldoauth2client
+    py27newoauth2client
+    pypy
+    py34
+    py35oauth2client15
+    py35oauth2client30
+    py35oauth2client41
+    lint
+    cover
 
 [testenv]
 deps =
@@ -10,6 +21,12 @@ commands =
     nosetests []
 passenv = TRAVIS*
 
+[testenv:py27newoauth2client]
+commands =
+    pip install oauth2client==4.1.0
+    {[testenv]commands}
+deps = {[testenv]deps}
+
 [testenv:py27oldoauth2client]
 commands =
     pip install oauth2client==1.5.2
@@ -24,6 +41,36 @@ deps =
     unittest2
 commands =
     nosetests []
+[testenv:py35oauth2client15]
+basepython = python3.5
+deps =
+    mock
+    nose
+    unittest2
+commands =
+    pip install oauth2client==1.5.2
+    nosetests []
+
+[testenv:py35oauth2client30]
+basepython = python3.5
+deps =
+    mock
+    nose
+    unittest2
+commands =
+    pip install oauth2client==3.0.0
+    nosetests []
+
+[testenv:py35oauth2client41]
+basepython = python3.5
+deps =
+    mock
+    nose
+    unittest2
+commands =
+    pip install oauth2client==4.1.0
+    nosetests []
+
 [testenv:py35]
 basepython = python3.5
 deps =
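
For reference, a minimal usage sketch of the new _MultiProcessCacheFile cache added in credentials_lib.py. This is not part of the patch itself: the cache path and the metadata payload below are illustrative, but the LockedWrite/LockedRead calls match the class added above (LockedWrite returns False and LockedRead returns None when the locks cannot be acquired within the timeout, which callers treat as a cache miss).

import json
import os
import tempfile

from apitools.base.py.credentials_lib import _MultiProcessCacheFile

# Illustrative cache location only.
cache_path = os.path.join(tempfile.mkdtemp(), 'gce_creds_cache')
cache = _MultiProcessCacheFile(cache_path)

# Writer side: serialize the credentials metadata dict, as _WriteCacheFile
# does, and write it under the thread lock plus fasteners interprocess lock.
creds = {'scopes': ['scope1'], 'svc_acct_name': 'default'}
wrote = cache.LockedWrite(json.dumps(creds))

# Reader side: returns the cached string, or None on a lock or I/O failure,
# which _CheckCacheFileForMatch treats as a cache miss.
cached = cache.LockedRead()
if wrote and cached:
    print(json.loads(cached)['scopes'])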