Skip to content

Commit

Permalink
Merge pull request #98 from int-brain-lab/v2.4.0
Browse files Browse the repository at this point in the history
V2.4.0
  • Loading branch information
k1o0 authored Oct 17, 2023
2 parents 180aa5e + 63d6482 commit 31bfe65
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 17 deletions.
14 changes: 13 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# Changelog
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.3.0]
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.4.0]

### Added

- one.remote.aws.url2uri function converts generic an Amazon virtual host URL to an S3 URI
- Globus._remove_token_fields method removes Globus token fields before saving parameters
- one.alf.files.padded_sequence function ensures file paths contain zero-padded experiment sequence folder

### Modified

- one.remote.globus.get_local_endpoint_id no longer prints ID to std out; uses debug log instead

## [2.3.0]

### Added

Expand Down
2 changes: 1 addition & 1 deletion one/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""The Open Neurophysiology Environment (ONE) API."""
__version__ = '2.3.0'
__version__ = '2.4.0'
36 changes: 36 additions & 0 deletions one/alf/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,39 @@ def remove_uuid_string(file_path):
if spec.is_uuid_string(name_parts[-1]):
file_path = file_path.with_name('.'.join(name_parts[:-1]) + file_path.suffix)
return file_path


def padded_sequence(filepath):
"""
Ensures a file path contains a zero-padded experiment sequence folder.
Parameters
----------
filepath : str, pathlib.Path, pathlib.PurePath
A session or file path to convert.
Returns
-------
pathlib.Path, pathlib.PurePath
The same path but with the experiment sequence folder zero-padded. If a PurePath was
passed, a PurePath will be returned, otherwise a Path object is returned.
Examples
--------
>>> filepath = '/iblrigdata/subject/2023-01-01/1/_ibl_experiment.description.yaml'
>>> padded_sequence(filepath)
pathlib.Path('/iblrigdata/subject/2023-01-01/001/_ibl_experiment.description.yaml')
Supports folders and will not affect already padded paths
>>> session_path = pathlib.PurePosixPath('subject/2023-01-01/001')
>>> padded_sequence(filepath)
pathlib.PurePosixPath('subject/2023-01-01/001')
"""
if isinstance(filepath, str):
filepath = Path(filepath)
if (session_path := get_session_path(filepath)) is None:
raise ValueError('path must include a valid ALF session path, e.g. subject/YYYY-MM-DD/N')
idx = len(filepath.parts) - len(session_path.parts)
sequence = str(int(session_path.parts[-1])).zfill(3) # zero-pad if necessary
return filepath.parents[idx].joinpath(sequence, filepath.relative_to(session_path))
25 changes: 25 additions & 0 deletions one/remote/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,31 @@ def get_s3_virtual_host(uri, region) -> str:
return 'https://' + '/'.join((hostname, *key))


def url2uri(data_path, return_location=False):
"""
Convert a generic Amazon virtual host URL to an S3 URI.
Parameters
----------
data_path : str
An Amazon virtual host URL to convert.
return_location : bool
If true, additionally returns the location string.
Returns
-------
str
An S3 URI with scheme 's3://'.
str
If return_location is true, returns the bucket location, e.g. 'eu-east-1'.
"""
parsed = urllib.parse.urlparse(data_path)
assert parsed.netloc and parsed.scheme and parsed.path
bucket_name, _, loc, *_ = parsed.netloc.split('.')
uri = f's3://{bucket_name}{parsed.path}'
return (uri, loc) if return_location else uri


def is_folder(obj_summery) -> bool:
"""
Given an S3 ObjectSummery instance, returns true if the associated object is a directory.
Expand Down
53 changes: 41 additions & 12 deletions one/remote/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def get_local_endpoint_id():
assert id_file.exists(), msg.format(id_file)
local_id = id_file.read_text().strip()
assert isinstance(local_id, str), msg.format(id_file)
print(f'Found local endpoint ID in Globus Connect settings {local_id}')
_logger.debug(f'Found local endpoint ID in Globus Connect settings {local_id}')
return UUID(local_id)


Expand Down Expand Up @@ -517,15 +517,20 @@ def is_logged_in(self):

@property
def _token_expired(self):
"""bool or None: True if token expired; False if still valid; None if token not present.
"""bool: True if token absent or expired; False if valid
Note the 'expires_at_seconds' may be greater than `Globus.client.authorizer.expires_at` if
using refresh tokens. The `login` method will always refresh the token if still valid.
"""
try:
return getattr(self._pars, 'expires_at_seconds') - datetime.utcnow().timestamp() < 60
except AttributeError:
return
authorizer = getattr(self.client, 'authorizer', None)
has_refresh_token = self._pars.as_dict().get('refresh_token') is not None
if has_refresh_token and isinstance(authorizer, globus_sdk.RefreshTokenAuthorizer):
self.client.authorizer.ensure_valid_token() # Fetch new refresh token if needed
except Exception as ex:
_logger.debug('Failed to refresh token: %s', ex)
expires_at_seconds = getattr(self._pars, 'expires_at_seconds', 0)
return expires_at_seconds - datetime.utcnow().timestamp() < 60

def login(self, stay_logged_in=None):
"""
Expand All @@ -540,15 +545,19 @@ def login(self, stay_logged_in=None):
"""
if self.is_logged_in:
_logger.debug('Already logged in')
self.client.authorizer.ensure_valid_token() # refresh token if necessary
return

# Default depends on refresh token
stay_logged_in = True if stay_logged_in is None else stay_logged_in
expired = bool(
self._pars.as_dict().get('refresh_token') is None
if stay_logged_in else self._token_expired
)
# If no tokens in parameters, Globus must be authenticated
required_fields = {'refresh_token', 'access_token', 'expires_at_seconds'}
if not required_fields.issubset(iopar.as_dict(self._pars)) or self._token_expired:
if not required_fields.issubset(iopar.as_dict(self._pars)) or expired:
if self.headless:
raise RuntimeError(f'Globus not authenticated for client "{self.client_name}"')
stay_logged_in = True if stay_logged_in is None else stay_logged_in
token = get_token(self._pars.GLOBUS_CLIENT_ID, refresh_tokens=stay_logged_in)
if not any(token.values()):
_logger.debug('Login cancelled by user')
Expand All @@ -572,18 +581,38 @@ def logout(self):

def _authenticate(self, stay_logged_in=None):
"""Authenticate and instantiate Globus SDK client."""
if self._token_expired is not False:
raise RuntimeError(f'token no longer valid for client "{self.client_name}"')
if self._pars.as_dict().get('refresh_token') and stay_logged_in is not False:
client = globus_sdk.NativeAppAuthClient(self._pars.GLOBUS_CLIENT_ID)
client.oauth2_start_flow(refresh_tokens=True)
authorizer = globus_sdk.RefreshTokenAuthorizer(self._pars.refresh_token, client)
authorizer = globus_sdk.RefreshTokenAuthorizer(
self._pars.refresh_token, client, on_refresh=self._save_refresh_token_callback)
else:
if stay_logged_in is True:
warnings.warn('No refresh token. Please log out and back in to remain logged in.')
if self._token_expired is not False:
raise RuntimeError(f'token no longer valid for client "{self.client_name}"')
authorizer = globus_sdk.AccessTokenAuthorizer(self._pars.access_token)
self.client = globus_sdk.TransferClient(authorizer=authorizer)

def _save_refresh_token_callback(self, res):
"""
Save a token fetched by the refresh token authorizer.
This is a callback for the globus_sdk.RefreshTokenAuthorizer to update the parameters.
Parameters
----------
res : globus_sdk.services.auth.OAuthTokenResponse
An Open Authorization response object.
"""
if not res or not (token := next(iter(res.by_resource_server.values()), None)):
return
token_fields = {'refresh_token', 'access_token', 'expires_at_seconds'}
self._pars = iopar.from_dict(
{**self._pars.as_dict(), **{k: v for k, v in token.items() if k in token_fields}})
_save_globus_params(self._pars, self.client_name)

def fetch_endpoints_from_alyx(self, alyx=None, overwrite=False):
"""
Update endpoints property with Alyx Globus data repositories.
Expand Down Expand Up @@ -685,7 +714,7 @@ def download_file(self, file_address, source_endpoint, recursive=False, **kwargs
return_single = isinstance(file_address, str) and recursive is False
kwargs['label'] = kwargs.get('label', 'ONE download')
task = partial(self.transfer_data, file_address, source_endpoint, 'local',
sync_level='mtime', recursive=recursive, **kwargs)
recursive=recursive, **kwargs)
task_id = self.run_task(task)
files = []
root = Path(self.endpoints['local']['root_path'])
Expand Down
17 changes: 16 additions & 1 deletion one/tests/alf/test_alf_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Unit tests for the one.alf.files module."""
import unittest
from pathlib import Path
from pathlib import Path, PureWindowsPath
import uuid

import one.alf.files as files
Expand Down Expand Up @@ -178,6 +178,21 @@ def test_remove_uuid(self):
desired_output = Path('toto.npy')
self.assertEqual(desired_output, files.remove_uuid_string(file_path))

def test_padded_sequence(self):
"""Test for one.alf.files.padded_sequence."""
# Test with pure path file input
filepath = PureWindowsPath(r'F:\ScanImageAcquisitions\subject\2023-01-01\1\foo\bar.baz')
expected = PureWindowsPath(r'F:\ScanImageAcquisitions\subject\2023-01-01\001\foo\bar.baz')
self.assertEqual(files.padded_sequence(filepath), expected)

# Test with str input session path
session_path = '/mnt/s0/Data/Subjects/subject/2023-01-01/001'
expected = Path('/mnt/s0/Data/Subjects/subject/2023-01-01/001')
self.assertEqual(files.padded_sequence(session_path), expected)

# Test invalid ALF session path
self.assertRaises(ValueError, files.padded_sequence, '/foo/bar/baz')


class TestALFGet(unittest.TestCase):
"""Tests for path extraction functions"""
Expand Down
10 changes: 9 additions & 1 deletion one/tests/remote/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class TestUtils(unittest.TestCase):
"""Tests for one.remote.aws utility functions"""

def test_get_s3_virtual_host(self):
"""Tests for one.remote.aws.get_s3_virtual_host function"""
"""Tests for one.remote.aws.get_s3_virtual_host function."""
expected = 'https://my-s3-bucket.s3.eu-east-1.amazonaws.com/'
url = aws.get_s3_virtual_host('s3://my-s3-bucket', 'eu-east-1')
self.assertEqual(expected, url)
Expand All @@ -135,6 +135,14 @@ def test_get_s3_virtual_host(self):
with self.assertRaises(AssertionError):
aws.get_s3_virtual_host('s3://my-s3-bucket/path/to/file', 'wrong-foo-4')

def test_url2uri(self):
"""Tests for one.remote.aws.url2uri function."""
url = 'https://my-s3-bucket.s3.eu-east-1.amazonaws.com/path/to/file'
expected = 's3://my-s3-bucket/path/to/file'
self.assertEqual(aws.url2uri(url), expected)
uri, loc = aws.url2uri(url, return_location=True)
self.assertEqual(loc, 'eu-east-1')


if __name__ == '__main__':
unittest.main(exit=False)
35 changes: 34 additions & 1 deletion one/tests/remote/test_globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

try:
import globus_sdk
import globus_sdk.services.auth
except ModuleNotFoundError:
raise unittest.skip('globus_sdk module not installed')
from iblutil.io import params as iopar
Expand Down Expand Up @@ -572,7 +573,11 @@ def test_globus_headless(self):
def test_login_logout(self):
"""Test for Globus.login and Globus.logout methods."""
assert self.globus.is_logged_in
sdk_mock, _ = self.globus_sdk_mock.get_original()
with self.assertLogs('one.remote.globus', 10):
# Token validator checks token auth class, which is mocked, so here we set the
# RefreshTokenAuthorizer to a MagicMock so that the types match
sdk_mock.RefreshTokenAuthorizer = mock.MagicMock
self.globus.login()
self.globus.client.authorizer.ensure_valid_token.assert_called()

Expand All @@ -590,7 +595,7 @@ def test_login_logout(self):
self.assertFalse(hasattr(self.globus.client.authorizer, 'access_token'))
self.assertFalse(hasattr(self.globus._pars, 'access_token'))
self.assertFalse(self.globus.is_logged_in)
self.assertIsNone(self.globus._token_expired)
self.assertTrue(self.globus._token_expired)

# Test what happens when authenticate called with invalid token
self.assertRaises(RuntimeError, self.globus._authenticate)
Expand Down Expand Up @@ -621,6 +626,34 @@ def test_login_logout(self):
self.assertWarns(UserWarning, self.globus.login, stay_logged_in=True)
self.assertTrue(self.globus.is_logged_in)

def test_save_refresh_token_callback(self):
"""Test for Globus._save_refresh_token_callback method."""
assert hasattr(self.globus._pars, 'refresh_token')
token = {'refresh_token': '567', 'access_token': 'abc', 'expires_at_seconds': 100000000}
res = mock.MagicMock(spec=globus_sdk.services.auth.OAuthTokenResponse)
res.by_resource_server = dict(server=token)

# Check behaviour when called with Globus auth response
with mock.patch('one.remote.globus.save_client_params') as client_params_mock:
self.globus._save_refresh_token_callback(res)
client_params_mock.assert_called_once()
(pars, name), _ = client_params_mock.call_args
self.assertEqual(name, 'globus')
self.assertIn(self.globus.client_name, pars)
self.assertTrue(set(pars[self.globus.client_name]) >= set(globus.DEFAULT_PAR))
for k, v in token.items():
with self.subTest(k):
self.assertEqual(pars[self.globus.client_name].get(k), v)
# Obj params should be modified
par_vals = map(partial(getattr, self.globus._pars), token.keys())
self.assertCountEqual(par_vals, token.values())

# Check behaviour when called with empty auth response
res.by_resource_server = dict()
with mock.patch('one.remote.globus.save_client_params') as client_params_mock:
self.globus._save_refresh_token_callback(res)
client_params_mock.assert_not_called()


class TestGlobusAsync(unittest.IsolatedAsyncioTestCase, _GlobusClientTest):
"""Asynchronous Globus method tests."""
Expand Down

0 comments on commit 31bfe65

Please sign in to comment.