Skip to content

Commit

Permalink
Merge branch 'globusToken' into v2.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Oct 16, 2023
2 parents 75a14c4 + 22c0b34 commit 6f3c6fc
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
51 changes: 40 additions & 11 deletions one/remote/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,15 +517,20 @@ def is_logged_in(self):

@property
def _token_expired(self):
"""bool or None: True if token expired; False if still valid; None if token not present.
"""bool: True if token absent or expired; False if valid
Note the 'expires_at_seconds' may be greater than `Globus.client.authorizer.expires_at` if
using refresh tokens. The `login` method will always refresh the token if still valid.
"""
try:
return getattr(self._pars, 'expires_at_seconds') - datetime.utcnow().timestamp() < 60
except AttributeError:
return
authorizer = getattr(self.client, 'authorizer', None)
has_refresh_token = self._pars.as_dict().get('refresh_token') is not None
if has_refresh_token and isinstance(authorizer, globus_sdk.RefreshTokenAuthorizer):
self.client.authorizer.ensure_valid_token() # Fetch new refresh token if needed
except Exception as ex:
_logger.debug('Failed to refresh token: %s', ex)
expires_at_seconds = getattr(self._pars, 'expires_at_seconds', 0)
return expires_at_seconds - datetime.utcnow().timestamp() < 60

def login(self, stay_logged_in=None):
"""
Expand All @@ -540,15 +545,19 @@ def login(self, stay_logged_in=None):
"""
if self.is_logged_in:
_logger.debug('Already logged in')
self.client.authorizer.ensure_valid_token() # refresh token if necessary
return

# Default depends on refresh token
stay_logged_in = True if stay_logged_in is None else stay_logged_in
expired = bool(
self._pars.as_dict().get('refresh_token') is None
if stay_logged_in else self._token_expired
)
# If no tokens in parameters, Globus must be authenticated
required_fields = {'refresh_token', 'access_token', 'expires_at_seconds'}
if not required_fields.issubset(iopar.as_dict(self._pars)) or self._token_expired:
if not required_fields.issubset(iopar.as_dict(self._pars)) or expired:
if self.headless:
raise RuntimeError(f'Globus not authenticated for client "{self.client_name}"')
stay_logged_in = True if stay_logged_in is None else stay_logged_in
token = get_token(self._pars.GLOBUS_CLIENT_ID, refresh_tokens=stay_logged_in)
if not any(token.values()):
_logger.debug('Login cancelled by user')
Expand All @@ -572,18 +581,38 @@ def logout(self):

def _authenticate(self, stay_logged_in=None):
"""Authenticate and instantiate Globus SDK client."""
if self._token_expired is not False:
raise RuntimeError(f'token no longer valid for client "{self.client_name}"')
if self._pars.as_dict().get('refresh_token') and stay_logged_in is not False:
client = globus_sdk.NativeAppAuthClient(self._pars.GLOBUS_CLIENT_ID)
client.oauth2_start_flow(refresh_tokens=True)
authorizer = globus_sdk.RefreshTokenAuthorizer(self._pars.refresh_token, client)
authorizer = globus_sdk.RefreshTokenAuthorizer(
self._pars.refresh_token, client, on_refresh=self._save_refresh_token_callback)
else:
if stay_logged_in is True:
warnings.warn('No refresh token. Please log out and back in to remain logged in.')
if self._token_expired is not False:
raise RuntimeError(f'token no longer valid for client "{self.client_name}"')
authorizer = globus_sdk.AccessTokenAuthorizer(self._pars.access_token)
self.client = globus_sdk.TransferClient(authorizer=authorizer)

def _save_refresh_token_callback(self, res):
"""
Save a token fetched by the refresh token authorizer.
This is a callback for the globus_sdk.RefreshTokenAuthorizer to update the parameters.
Parameters
----------
res : globus_sdk.services.auth.OAuthTokenResponse
An Open Authorization response object.
"""
if not res or not (token := next(iter(res.by_resource_server.values()), None)):
return
token_fields = {'refresh_token', 'access_token', 'expires_at_seconds'}
self._pars = iopar.from_dict(
{**self._pars.as_dict(), **{k: v for k, v in token.items() if k in token_fields}})
_save_globus_params(self._pars, self.client_name)

def fetch_endpoints_from_alyx(self, alyx=None, overwrite=False):
"""
Update endpoints property with Alyx Globus data repositories.
Expand Down Expand Up @@ -685,7 +714,7 @@ def download_file(self, file_address, source_endpoint, recursive=False, **kwargs
return_single = isinstance(file_address, str) and recursive is False
kwargs['label'] = kwargs.get('label', 'ONE download')
task = partial(self.transfer_data, file_address, source_endpoint, 'local',
sync_level='mtime', recursive=recursive, **kwargs)
recursive=recursive, **kwargs)
task_id = self.run_task(task)
files = []
root = Path(self.endpoints['local']['root_path'])
Expand Down
35 changes: 34 additions & 1 deletion one/tests/remote/test_globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

try:
import globus_sdk
import globus_sdk.services.auth
except ModuleNotFoundError:
raise unittest.skip('globus_sdk module not installed')
from iblutil.io import params as iopar
Expand Down Expand Up @@ -572,7 +573,11 @@ def test_globus_headless(self):
def test_login_logout(self):
"""Test for Globus.login and Globus.logout methods."""
assert self.globus.is_logged_in
sdk_mock, _ = self.globus_sdk_mock.get_original()
with self.assertLogs('one.remote.globus', 10):
# Token validator checks token auth class, which is mocked, so here we set the
# RefreshTokenAuthorizer to a MagicMock so that the types match
sdk_mock.RefreshTokenAuthorizer = mock.MagicMock
self.globus.login()
self.globus.client.authorizer.ensure_valid_token.assert_called()

Expand All @@ -590,7 +595,7 @@ def test_login_logout(self):
self.assertFalse(hasattr(self.globus.client.authorizer, 'access_token'))
self.assertFalse(hasattr(self.globus._pars, 'access_token'))
self.assertFalse(self.globus.is_logged_in)
self.assertIsNone(self.globus._token_expired)
self.assertTrue(self.globus._token_expired)

# Test what happens when authenticate called with invalid token
self.assertRaises(RuntimeError, self.globus._authenticate)
Expand Down Expand Up @@ -621,6 +626,34 @@ def test_login_logout(self):
self.assertWarns(UserWarning, self.globus.login, stay_logged_in=True)
self.assertTrue(self.globus.is_logged_in)

def test_save_refresh_token_callback(self):
"""Test for Globus._save_refresh_token_callback method."""
assert hasattr(self.globus._pars, 'refresh_token')
token = {'refresh_token': '567', 'access_token': 'abc', 'expires_at_seconds': 100000000}
res = mock.MagicMock(spec=globus_sdk.services.auth.OAuthTokenResponse)
res.by_resource_server = dict(server=token)

# Check behaviour when called with Globus auth response
with mock.patch('one.remote.globus.save_client_params') as client_params_mock:
self.globus._save_refresh_token_callback(res)
client_params_mock.assert_called_once()
(pars, name), _ = client_params_mock.call_args
self.assertEqual(name, 'globus')
self.assertIn(self.globus.client_name, pars)
self.assertTrue(set(pars[self.globus.client_name]) >= set(globus.DEFAULT_PAR))
for k, v in token.items():
with self.subTest(k):
self.assertEqual(pars[self.globus.client_name].get(k), v)
# Obj params should be modified
par_vals = map(partial(getattr, self.globus._pars), token.keys())
self.assertCountEqual(par_vals, token.values())

# Check behaviour when called with empty auth response
res.by_resource_server = dict()
with mock.patch('one.remote.globus.save_client_params') as client_params_mock:
self.globus._save_refresh_token_callback(res)
client_params_mock.assert_not_called()


class TestGlobusAsync(unittest.IsolatedAsyncioTestCase, _GlobusClientTest):
"""Asynchronous Globus method tests."""
Expand Down

0 comments on commit 6f3c6fc

Please sign in to comment.