diff --git a/CHANGELOG.md b/CHANGELOG.md index 778bc9b5..af4e2530 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog -## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.3.0] +## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.4.0] + +### Added + +- one.remote.aws.url2uri function converts a generic Amazon virtual host URL to an S3 URI +- Globus._remove_token_fields method removes Globus token fields before saving parameters +- one.alf.files.padded_sequence function ensures file paths contain zero-padded experiment sequence folder + +### Modified + +- one.remote.globus.get_local_endpoint_id no longer prints ID to std out; uses debug log instead + +## [2.3.0] ### Added diff --git a/one/__init__.py b/one/__init__.py index ee5172fb..1810db5d 100644 --- a/one/__init__.py +++ b/one/__init__.py @@ -1,2 +1,2 @@ """The Open Neurophysiology Environment (ONE) API.""" -__version__ = '2.3.0' +__version__ = '2.4.0' diff --git a/one/alf/files.py b/one/alf/files.py index 7e4edb2e..73c040db 100644 --- a/one/alf/files.py +++ b/one/alf/files.py @@ -446,3 +446,39 @@ def remove_uuid_string(file_path): if spec.is_uuid_string(name_parts[-1]): file_path = file_path.with_name('.'.join(name_parts[:-1]) + file_path.suffix) return file_path + + +def padded_sequence(filepath): + """ + Ensures a file path contains a zero-padded experiment sequence folder. + + Parameters + ---------- + filepath : str, pathlib.Path, pathlib.PurePath + A session or file path to convert. + + Returns + ------- + pathlib.Path, pathlib.PurePath + The same path but with the experiment sequence folder zero-padded. If a PurePath was + passed, a PurePath will be returned, otherwise a Path object is returned. 
+ + Examples + -------- + >>> filepath = '/iblrigdata/subject/2023-01-01/1/_ibl_experiment.description.yaml' + >>> padded_sequence(filepath) + pathlib.Path('/iblrigdata/subject/2023-01-01/001/_ibl_experiment.description.yaml') + + Supports folders and will not affect already padded paths + + >>> session_path = pathlib.PurePosixPath('subject/2023-01-01/001') + >>> padded_sequence(session_path) + pathlib.PurePosixPath('subject/2023-01-01/001') + """ + if isinstance(filepath, str): + filepath = Path(filepath) + if (session_path := get_session_path(filepath)) is None: + raise ValueError('path must include a valid ALF session path, e.g. subject/YYYY-MM-DD/N') + idx = len(filepath.parts) - len(session_path.parts) + sequence = str(int(session_path.parts[-1])).zfill(3) # zero-pad if necessary + return filepath.parents[idx].joinpath(sequence, filepath.relative_to(session_path)) diff --git a/one/remote/aws.py b/one/remote/aws.py index d08654fa..0c9b41dc 100644 --- a/one/remote/aws.py +++ b/one/remote/aws.py @@ -82,6 +82,31 @@ def get_s3_virtual_host(uri, region) -> str: return 'https://' + '/'.join((hostname, *key)) + +def url2uri(data_path, return_location=False): + """ + Convert a generic Amazon virtual host URL to an S3 URI. + + Parameters + ---------- + data_path : str + An Amazon virtual host URL to convert. + return_location : bool + If true, additionally returns the location string. + + Returns + ------- + str + An S3 URI with scheme 's3://'. + str + If return_location is true, returns the bucket location, e.g. 'eu-east-1'. + """ + parsed = urllib.parse.urlparse(data_path) + assert parsed.netloc and parsed.scheme and parsed.path + bucket_name, _, loc, *_ = parsed.netloc.split('.') + uri = f's3://{bucket_name}{parsed.path}' + return (uri, loc) if return_location else uri + + def is_folder(obj_summery) -> bool: + """ + Given an S3 ObjectSummery instance, returns true if the associated object is a directory. 
diff --git a/one/remote/globus.py b/one/remote/globus.py index d12169a5..161ef310 100644 --- a/one/remote/globus.py +++ b/one/remote/globus.py @@ -344,7 +344,7 @@ def get_local_endpoint_id(): assert id_file.exists(), msg.format(id_file) local_id = id_file.read_text().strip() assert isinstance(local_id, str), msg.format(id_file) - print(f'Found local endpoint ID in Globus Connect settings {local_id}') + _logger.debug(f'Found local endpoint ID in Globus Connect settings {local_id}') return UUID(local_id) @@ -517,15 +517,20 @@ def is_logged_in(self): @property def _token_expired(self): - """bool or None: True if token expired; False if still valid; None if token not present. + """bool: True if token absent or expired; False if valid Note the 'expires_at_seconds' may be greater than `Globus.client.authorizer.expires_at` if using refresh tokens. The `login` method will always refresh the token if still valid. """ try: - return getattr(self._pars, 'expires_at_seconds') - datetime.utcnow().timestamp() < 60 - except AttributeError: - return + authorizer = getattr(self.client, 'authorizer', None) + has_refresh_token = self._pars.as_dict().get('refresh_token') is not None + if has_refresh_token and isinstance(authorizer, globus_sdk.RefreshTokenAuthorizer): + self.client.authorizer.ensure_valid_token() # Fetch new refresh token if needed + except Exception as ex: + _logger.debug('Failed to refresh token: %s', ex) + expires_at_seconds = getattr(self._pars, 'expires_at_seconds', 0) + return expires_at_seconds - datetime.utcnow().timestamp() < 60 def login(self, stay_logged_in=None): """ @@ -540,15 +545,19 @@ def login(self, stay_logged_in=None): """ if self.is_logged_in: _logger.debug('Already logged in') - self.client.authorizer.ensure_valid_token() # refresh token if necessary return + # Default depends on refresh token + stay_logged_in = True if stay_logged_in is None else stay_logged_in + expired = bool( + self._pars.as_dict().get('refresh_token') is None + if 
stay_logged_in else self._token_expired + ) # If no tokens in parameters, Globus must be authenticated required_fields = {'refresh_token', 'access_token', 'expires_at_seconds'} - if not required_fields.issubset(iopar.as_dict(self._pars)) or self._token_expired: + if not required_fields.issubset(iopar.as_dict(self._pars)) or expired: if self.headless: raise RuntimeError(f'Globus not authenticated for client "{self.client_name}"') - stay_logged_in = True if stay_logged_in is None else stay_logged_in token = get_token(self._pars.GLOBUS_CLIENT_ID, refresh_tokens=stay_logged_in) if not any(token.values()): _logger.debug('Login cancelled by user') @@ -572,18 +581,38 @@ def logout(self): def _authenticate(self, stay_logged_in=None): """Authenticate and instantiate Globus SDK client.""" - if self._token_expired is not False: - raise RuntimeError(f'token no longer valid for client "{self.client_name}"') if self._pars.as_dict().get('refresh_token') and stay_logged_in is not False: client = globus_sdk.NativeAppAuthClient(self._pars.GLOBUS_CLIENT_ID) client.oauth2_start_flow(refresh_tokens=True) - authorizer = globus_sdk.RefreshTokenAuthorizer(self._pars.refresh_token, client) + authorizer = globus_sdk.RefreshTokenAuthorizer( + self._pars.refresh_token, client, on_refresh=self._save_refresh_token_callback) else: if stay_logged_in is True: warnings.warn('No refresh token. Please log out and back in to remain logged in.') + if self._token_expired is not False: + raise RuntimeError(f'token no longer valid for client "{self.client_name}"') authorizer = globus_sdk.AccessTokenAuthorizer(self._pars.access_token) self.client = globus_sdk.TransferClient(authorizer=authorizer) + def _save_refresh_token_callback(self, res): + """ + Save a token fetched by the refresh token authorizer. + + This is a callback for the globus_sdk.RefreshTokenAuthorizer to update the parameters. 
+ + Parameters + ---------- + res : globus_sdk.services.auth.OAuthTokenResponse + An Open Authorization response object. + + """ + if not res or not (token := next(iter(res.by_resource_server.values()), None)): + return + token_fields = {'refresh_token', 'access_token', 'expires_at_seconds'} + self._pars = iopar.from_dict( + {**self._pars.as_dict(), **{k: v for k, v in token.items() if k in token_fields}}) + _save_globus_params(self._pars, self.client_name) + def fetch_endpoints_from_alyx(self, alyx=None, overwrite=False): """ Update endpoints property with Alyx Globus data repositories. @@ -685,7 +714,7 @@ def download_file(self, file_address, source_endpoint, recursive=False, **kwargs return_single = isinstance(file_address, str) and recursive is False kwargs['label'] = kwargs.get('label', 'ONE download') task = partial(self.transfer_data, file_address, source_endpoint, 'local', - sync_level='mtime', recursive=recursive, **kwargs) + recursive=recursive, **kwargs) task_id = self.run_task(task) files = [] root = Path(self.endpoints['local']['root_path']) diff --git a/one/tests/alf/test_alf_files.py b/one/tests/alf/test_alf_files.py index 1345ddcb..87eee4f6 100644 --- a/one/tests/alf/test_alf_files.py +++ b/one/tests/alf/test_alf_files.py @@ -1,6 +1,6 @@ """Unit tests for the one.alf.files module.""" import unittest -from pathlib import Path +from pathlib import Path, PureWindowsPath import uuid import one.alf.files as files @@ -178,6 +178,21 @@ def test_remove_uuid(self): desired_output = Path('toto.npy') self.assertEqual(desired_output, files.remove_uuid_string(file_path)) + def test_padded_sequence(self): + """Test for one.alf.files.padded_sequence.""" + # Test with pure path file input + filepath = PureWindowsPath(r'F:\ScanImageAcquisitions\subject\2023-01-01\1\foo\bar.baz') + expected = PureWindowsPath(r'F:\ScanImageAcquisitions\subject\2023-01-01\001\foo\bar.baz') + self.assertEqual(files.padded_sequence(filepath), expected) + + # Test with str input session 
path + session_path = '/mnt/s0/Data/Subjects/subject/2023-01-01/001' + expected = Path('/mnt/s0/Data/Subjects/subject/2023-01-01/001') + self.assertEqual(files.padded_sequence(session_path), expected) + + # Test invalid ALF session path + self.assertRaises(ValueError, files.padded_sequence, '/foo/bar/baz') + class TestALFGet(unittest.TestCase): """Tests for path extraction functions""" diff --git a/one/tests/remote/test_aws.py b/one/tests/remote/test_aws.py index 542226e5..fc820c89 100644 --- a/one/tests/remote/test_aws.py +++ b/one/tests/remote/test_aws.py @@ -120,7 +120,7 @@ class TestUtils(unittest.TestCase): """Tests for one.remote.aws utility functions""" def test_get_s3_virtual_host(self): - """Tests for one.remote.aws.get_s3_virtual_host function""" + """Tests for one.remote.aws.get_s3_virtual_host function.""" expected = 'https://my-s3-bucket.s3.eu-east-1.amazonaws.com/' url = aws.get_s3_virtual_host('s3://my-s3-bucket', 'eu-east-1') self.assertEqual(expected, url) @@ -135,6 +135,14 @@ def test_get_s3_virtual_host(self): with self.assertRaises(AssertionError): aws.get_s3_virtual_host('s3://my-s3-bucket/path/to/file', 'wrong-foo-4') + def test_url2uri(self): + """Tests for one.remote.aws.url2uri function.""" + url = 'https://my-s3-bucket.s3.eu-east-1.amazonaws.com/path/to/file' + expected = 's3://my-s3-bucket/path/to/file' + self.assertEqual(aws.url2uri(url), expected) + uri, loc = aws.url2uri(url, return_location=True) + self.assertEqual(loc, 'eu-east-1') + if __name__ == '__main__': unittest.main(exit=False) diff --git a/one/tests/remote/test_globus.py b/one/tests/remote/test_globus.py index db49075f..47943d9b 100644 --- a/one/tests/remote/test_globus.py +++ b/one/tests/remote/test_globus.py @@ -13,6 +13,7 @@ try: import globus_sdk + import globus_sdk.services.auth except ModuleNotFoundError: raise unittest.skip('globus_sdk module not installed') from iblutil.io import params as iopar @@ -572,7 +573,11 @@ def test_globus_headless(self): def 
test_login_logout(self): """Test for Globus.login and Globus.logout methods.""" assert self.globus.is_logged_in + sdk_mock, _ = self.globus_sdk_mock.get_original() with self.assertLogs('one.remote.globus', 10): + # Token validator checks token auth class, which is mocked, so here we set the + # RefreshTokenAuthorizer to a MagicMock so that the types match + sdk_mock.RefreshTokenAuthorizer = mock.MagicMock self.globus.login() self.globus.client.authorizer.ensure_valid_token.assert_called() @@ -590,7 +595,7 @@ def test_login_logout(self): self.assertFalse(hasattr(self.globus.client.authorizer, 'access_token')) self.assertFalse(hasattr(self.globus._pars, 'access_token')) self.assertFalse(self.globus.is_logged_in) - self.assertIsNone(self.globus._token_expired) + self.assertTrue(self.globus._token_expired) # Test what happens when authenticate called with invalid token self.assertRaises(RuntimeError, self.globus._authenticate) @@ -621,6 +626,34 @@ def test_login_logout(self): self.assertWarns(UserWarning, self.globus.login, stay_logged_in=True) self.assertTrue(self.globus.is_logged_in) + def test_save_refresh_token_callback(self): + """Test for Globus._save_refresh_token_callback method.""" + assert hasattr(self.globus._pars, 'refresh_token') + token = {'refresh_token': '567', 'access_token': 'abc', 'expires_at_seconds': 100000000} + res = mock.MagicMock(spec=globus_sdk.services.auth.OAuthTokenResponse) + res.by_resource_server = dict(server=token) + + # Check behaviour when called with Globus auth response + with mock.patch('one.remote.globus.save_client_params') as client_params_mock: + self.globus._save_refresh_token_callback(res) + client_params_mock.assert_called_once() + (pars, name), _ = client_params_mock.call_args + self.assertEqual(name, 'globus') + self.assertIn(self.globus.client_name, pars) + self.assertTrue(set(pars[self.globus.client_name]) >= set(globus.DEFAULT_PAR)) + for k, v in token.items(): + with self.subTest(k): + 
self.assertEqual(pars[self.globus.client_name].get(k), v) + # Obj params should be modified + par_vals = map(partial(getattr, self.globus._pars), token.keys()) + self.assertCountEqual(par_vals, token.values()) + + # Check behaviour when called with empty auth response + res.by_resource_server = dict() + with mock.patch('one.remote.globus.save_client_params') as client_params_mock: + self.globus._save_refresh_token_callback(res) + client_params_mock.assert_not_called() + class TestGlobusAsync(unittest.IsolatedAsyncioTestCase, _GlobusClientTest): """Asynchronous Globus method tests."""