diff --git a/one/alf/cache.py b/one/alf/cache.py index 572ba210..9ca1fe26 100644 --- a/one/alf/cache.py +++ b/one/alf/cache.py @@ -11,7 +11,6 @@ >>> one = One(cache_dir=cache_dir) """ - # ------------------------------------------------------------------------------------------------- # Imports # ------------------------------------------------------------------------------------------------- @@ -24,42 +23,57 @@ import logging import pandas as pd +import numpy as np +from packaging import version from iblutil.io import parquet from iblutil.io.hashfile import md5 +from one.alf.spec import QC from one.alf.io import iter_sessions, iter_datasets from one.alf.files import session_path_parts, get_alf_path -from one.converters import session_record2path -from one.util import QC_TYPE, patch_cache -__all__ = [ - 'make_parquet_db', 'remove_missing_datasets', 'remove_cache_table_files', - 'DATASETS_COLUMNS', 'SESSIONS_COLUMNS'] +__all__ = ['make_parquet_db', 'patch_cache', 'remove_missing_datasets', + 'remove_cache_table_files', 'EMPTY_DATASETS_FRAME', 'EMPTY_SESSIONS_FRAME', 'QC_TYPE'] _logger = logging.getLogger(__name__) # ------------------------------------------------------------------------------------------------- # Global variables # ------------------------------------------------------------------------------------------------- -SESSIONS_COLUMNS = ( - 'id', # int64 - 'lab', # str - 'subject', # str - 'date', # datetime.date - 'number', # int - 'task_protocol', # str - 'projects', # str -) - -DATASETS_COLUMNS = ( - 'id', # int64 - 'eid', # int64 - 'rel_path', # relative to the session path, includes the filename - 'file_size', # file size in bytes - 'hash', # sha1/md5, computed in load function - 'exists', # bool - 'qc', # one.util.QC_TYPE -) +QC_TYPE = pd.CategoricalDtype(categories=[e.name for e in sorted(QC)], ordered=True) +"""pandas.api.types.CategoricalDtype: The cache table QC column data type.""" + +SESSIONS_COLUMNS = { + 'id': object, # str + 'lab': object, # str + 'subject': object, # str + 'date': object, # datetime.date + 'number': np.uint16, # int + 'task_protocol': object, # str + 'projects': object # str +} +"""dict: A map of sessions table fields and their data types.""" + +DATASETS_COLUMNS = { + 'eid': object, # str + 'id': object, # str + 'rel_path': object, # relative to the session path, includes the filename + 'file_size': 'UInt64', # file size in bytes (nullable) + 'hash': object, # sha1/md5, computed in load function + 'exists': bool, # bool + 'qc': QC_TYPE # one.alf.spec.QC enumeration +} +"""dict: A map of datasets table fields and their data types.""" + +EMPTY_DATASETS_FRAME = (pd.DataFrame(columns=DATASETS_COLUMNS) + .astype(DATASETS_COLUMNS) + .set_index(['eid', 'id'])) +"""pandas.DataFrame: An empty datasets dataframe with correct columns and dtypes.""" + +EMPTY_SESSIONS_FRAME = (pd.DataFrame(columns=SESSIONS_COLUMNS) + .astype(SESSIONS_COLUMNS) + .set_index('id')) +"""pandas.DataFrame: An empty sessions dataframe with correct columns and dtypes.""" # ------------------------------------------------------------------------------------------------- @@ -103,8 +117,7 @@ def _rel_path_to_uuid(df, id_key='rel_path', base_id=None, keep_old=False): toUUID = partial(uuid.uuid3, base_id) # MD5 hash from base uuid and rel session path string if keep_old: df[f'{id_key}_'] = df[id_key].copy() - df[id_key] = df[id_key].apply(lambda x: str(toUUID(x))) - assert len(df[id_key].unique()) == len(df[id_key]) # WARNING This fails :( + df.loc[:, id_key] = 
df.groupby(id_key)[id_key].transform(lambda x: str(toUUID(x.name))) return df @@ -173,7 +186,7 @@ def _make_sessions_df(root_dir) -> pd.DataFrame: ses_info = _get_session_info(rel_path) assert set(ses_info.keys()) <= set(SESSIONS_COLUMNS) rows.append(ses_info) - df = pd.DataFrame(rows, columns=SESSIONS_COLUMNS) + df = pd.DataFrame(rows, columns=SESSIONS_COLUMNS).astype(SESSIONS_COLUMNS) return df @@ -193,7 +206,7 @@ def _make_datasets_df(root_dir, hash_files=False) -> pd.DataFrame: pandas.DataFrame A pandas DataFrame of dataset info. """ - df = pd.DataFrame([], columns=DATASETS_COLUMNS).astype({'qc': QC_TYPE}) + df = EMPTY_DATASETS_FRAME.copy() # Go through sessions and append datasets for session_path in iter_sessions(root_dir): rows = [] @@ -201,7 +214,7 @@ def _make_datasets_df(root_dir, hash_files=False) -> pd.DataFrame: file_info = _get_dataset_info(session_path, rel_dset_path, compute_hash=hash_files) assert set(file_info.keys()) <= set(DATASETS_COLUMNS) rows.append(file_info) - df = pd.concat((df, pd.DataFrame(rows, columns=DATASETS_COLUMNS)), + df = pd.concat((df, pd.DataFrame(rows, columns=DATASETS_COLUMNS).astype(DATASETS_COLUMNS)), ignore_index=True, verify_integrity=True) return df.astype({'qc': QC_TYPE}) @@ -308,6 +321,7 @@ def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True, tables[name].set_index(idx_columns, inplace=True) to_delete = set() + from one.converters import session_record2path # imported here due to circular imports gen_path = partial(session_record2path, root_dir=cache_dir) # map of session path to eid sessions = {gen_path(rec): eid for eid, rec in tables['sessions'].iterrows()} @@ -367,3 +381,53 @@ def remove_cache_table_files(folder, tables=('sessions', 'datasets')): else: _logger.warning('%s not found', file) return removed + + +def _cache_int2str(table: pd.DataFrame) -> pd.DataFrame: + """Convert int ids to str ids for cache table. + + Parameters + ---------- + table : pd.DataFrame + A cache table (from One._cache). + + """ + # Convert integer uuids to str uuids + if table.index.nlevels < 2 or not any(x.endswith('_0') for x in table.index.names): + return table + table = table.reset_index() + int_cols = table.filter(regex=r'_\d{1}$').columns.sort_values() + assert not len(int_cols) % 2, 'expected even number of columns ending in _0 or _1' + names = sorted(set(c.rsplit('_', 1)[0] for c in int_cols.values)) + for i, name in zip(range(0, len(int_cols), 2), names): + table[name] = parquet.np2str(table[int_cols[i:i + 2]]) + table = table.drop(int_cols, axis=1).set_index(names) + return table + + +def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame: + """Reformat older cache tables to comply with this version of ONE. + + Currently this function will 1. convert integer UUIDs to string UUIDs; 2. rename the 'project' + column to 'projects'; 3. add QC column; 4. drop session_path column. + + Parameters + ---------- + table : pd.DataFrame + A cache table (from One._cache). + min_api_version : str + The minimum API version supported by this cache table. + name : {'dataset', 'session'} str + The name of the table. 
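+
+    Returns
+    -------
+    pd.DataFrame
+        A patched cache table.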
+ """ + min_version = version.parse(min_api_version or '0.0.0') + table = _cache_int2str(table) + # Rename project column + if min_version < version.Version('1.13.0') and 'project' in table.columns: + table.rename(columns={'project': 'projects'}, inplace=True) + if name == 'datasets' and min_version < version.Version('2.7.0') and 'qc' not in table.columns: + qc = pd.Categorical.from_codes(np.zeros(len(table.index), dtype=int), dtype=QC_TYPE) + table = table.assign(qc=qc) + if name == 'datasets' and 'session_path' in table.columns: + table = table.drop('session_path', axis=1) + return table diff --git a/one/api.py b/one/api.py index 1d6503fc..00da1bca 100644 --- a/one/api.py +++ b/one/api.py @@ -21,7 +21,7 @@ import packaging.version from iblutil.io import parquet, hashfile -from iblutil.util import Bunch, flatten, ensure_list +from iblutil.util import Bunch, flatten, ensure_list, Listable import one.params import one.webclient as wc @@ -29,10 +29,12 @@ import one.alf.files as alfiles import one.alf.exceptions as alferr from .alf.cache import ( - make_parquet_db, remove_cache_table_files, DATASETS_COLUMNS, SESSIONS_COLUMNS) + make_parquet_db, patch_cache, remove_cache_table_files, + EMPTY_DATASETS_FRAME, EMPTY_SESSIONS_FRAME +) from .alf.spec import is_uuid_string, QC, to_alf from . import __version__ -from one.converters import ConversionMixin, session_record2path +from one.converters import ConversionMixin, session_record2path, ses2records, datasets2records from one import util _logger = logging.getLogger(__name__) @@ -101,8 +103,8 @@ def search_terms(self, query_type=None) -> tuple: def _reset_cache(self): """Replace the cache object with a Bunch that contains the right fields.""" self._cache = Bunch({ - 'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']), - 'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'), + 'datasets': EMPTY_DATASETS_FRAME.copy(), + 'sessions': EMPTY_SESSIONS_FRAME.copy(), '_meta': { 'expired': False, 'created_time': None, @@ -160,7 +162,7 @@ def load_cache(self, tables_dir=None, **kwargs): cache.set_index(idx_columns, inplace=True) # Patch older tables - cache = util.patch_cache(cache, meta['raw'][table].get('min_api_version'), table) + cache = patch_cache(cache, meta['raw'][table].get('min_api_version'), table) # Check sorted # Sorting makes MultiIndex indexing O(N) -> O(1) @@ -173,6 +175,9 @@ def load_cache(self, tables_dir=None, **kwargs): # No tables present meta['expired'] = True meta['raw'] = {} + self._cache.update({ + 'datasets': EMPTY_DATASETS_FRAME.copy(), + 'sessions': EMPTY_SESSIONS_FRAME.copy()}) if self.offline: # In online mode, the cache tables should be downloaded later warnings.warn(f'No cache tables found in {self._tables_dir}') created = [datetime.fromisoformat(x['date_created']) @@ -286,7 +291,7 @@ def _update_cache_from_records(self, strict=False, **kwargs): Example ------- - >>> session, datasets = util.ses2records(self.get_details(eid, full=True)) + >>> session, datasets = ses2records(self.get_details(eid, full=True)) ... self._update_cache_from_records(sessions=session, datasets=datasets) Raises @@ -599,7 +604,7 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha datasets.index.set_names(idx_names, inplace=True) elif not isinstance(datasets, pd.DataFrame): # Cast set of dicts (i.e. 
from REST datasets endpoint) - datasets = util.datasets2records(list(datasets)) + datasets = datasets2records(list(datasets)) else: datasets = datasets.copy() indices_to_download = [] # indices of datasets that need (re)downloading @@ -1862,7 +1867,7 @@ def list_datasets( eid = self.to_eid(eid) # Ensure we have a UUID str list if not eid: return self._cache['datasets'].iloc[0:0] if details else [] # Return empty - session, datasets = util.ses2records(self.alyx.rest('sessions', 'read', id=eid)) + session, datasets = ses2records(self.alyx.rest('sessions', 'read', id=eid)) # Add to cache tables self._update_cache_from_records(sessions=session, datasets=datasets.copy()) if datasets is None or datasets.empty: @@ -1907,7 +1912,7 @@ def list_aggregates(self, relation: str, identifier: str = None, """ query = 'session__isnull,True' # ',data_repository_name__endswith,aggregates' all_aggregates = self.alyx.rest('datasets', 'list', django=query) - records = (util.datasets2records(all_aggregates) + records = (datasets2records(all_aggregates) .reset_index(level=0) .drop('eid', axis=1)) # Since rel_path for public FI file records starts with 'public/aggregates' instead of just @@ -2654,7 +2659,7 @@ def setup(base_url=None, **kwargs): @util.refresh @util.parse_id - def eid2path(self, eid, query_type=None) -> util.Listable(Path): + def eid2path(self, eid, query_type=None) -> Listable(Path): """ From an experiment ID gets the local session path @@ -2693,7 +2698,7 @@ def eid2path(self, eid, query_type=None) -> util.Listable(Path): str(ses[0]['number']).zfill(3)) @util.refresh - def path2eid(self, path_obj: Union[str, Path], query_type=None) -> util.Listable(Path): + def path2eid(self, path_obj: Union[str, Path], query_type=None) -> Listable(Path): """ From a local path, gets the experiment ID @@ -2797,7 +2802,7 @@ def type2datasets(self, eid, dataset_type, details=False): restriction = f'session__id,{eid},dataset_type__name__in,{dataset_type}' else: raise TypeError('dataset_type must be a str or str list') - datasets = util.datasets2records(self.alyx.rest('datasets', 'list', django=restriction)) + datasets = datasets2records(self.alyx.rest('datasets', 'list', django=restriction)) return datasets if details else datasets['rel_path'].sort_values().values def dataset2type(self, dset) -> str: diff --git a/one/converters.py b/one/converters.py index 20b80ea7..e2259ab1 100644 --- a/one/converters.py +++ b/one/converters.py @@ -10,17 +10,19 @@ import re import functools import datetime +import urllib.parse from uuid import UUID from inspect import unwrap from pathlib import Path, PurePosixPath from typing import Optional, Union, Mapping, List, Iterable as Iter import pandas as pd -from iblutil.util import Bunch +from iblutil.util import Bunch, Listable, ensure_list from one.alf.spec import is_session_path, is_uuid_string -from one.alf.files import get_session_path, add_uuid_string, session_path_parts, get_alf_path -from .util import Listable +from one.alf.cache import QC_TYPE, EMPTY_DATASETS_FRAME +from one.alf.files import ( + get_session_path, add_uuid_string, session_path_parts, get_alf_path, remove_uuid_string) def recurse(func): @@ -738,3 +740,97 @@ def session_record2path(session, root_dir=None): elif isinstance(root_dir, str): root_dir = Path(root_dir) return Path(root_dir).joinpath(rel_path) + + +def ses2records(ses: dict): + """Extract session cache record and datasets cache from a remote session data record. + + Parameters + ---------- + ses : dict + Session dictionary from Alyx REST endpoint. 
+ + Returns + ------- + pd.Series + Session record. + pd.DataFrame + Datasets frame. + """ + # Extract session record + eid = ses['url'][-36:] + session_keys = ('subject', 'start_time', 'lab', 'number', 'task_protocol', 'projects') + session_data = {k: v for k, v in ses.items() if k in session_keys} + session = ( + pd.Series(data=session_data, name=eid).rename({'start_time': 'date'}) + ) + session['projects'] = ','.join(session.pop('projects')) + session['date'] = datetime.datetime.fromisoformat(session['date']).date() + + # Extract datasets table + def _to_record(d): + rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True, id=d['id']) + rec['eid'] = session.name + file_path = urllib.parse.urlsplit(d['data_url'], allow_fragments=False).path.strip('/') + file_path = get_alf_path(remove_uuid_string(file_path)) + session_path = get_session_path(file_path).as_posix() + rec['rel_path'] = file_path[len(session_path):].strip('/') + rec['default_revision'] = d['default_revision'] == 'True' + rec['qc'] = d.get('qc', 'NOT_SET') + return rec + + if not ses.get('data_dataset_session_related'): + return session, EMPTY_DATASETS_FRAME.copy() + records = map(_to_record, ses['data_dataset_session_related']) + index = ['eid', 'id'] + datasets = pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE}) + return session, datasets + + +def datasets2records(datasets, additional=None) -> pd.DataFrame: + """Extract datasets DataFrame from one or more Alyx dataset records. + + Parameters + ---------- + datasets : dict, list + One or more records from the Alyx 'datasets' endpoint. + additional : list of str + A set of optional fields to extract from dataset records. + + Returns + ------- + pd.DataFrame + Datasets frame. + + Examples + -------- + >>> datasets = ONE().alyx.rest('datasets', 'list', subject='foobar') + >>> df = datasets2records(datasets) + """ + records = [] + + for d in ensure_list(datasets): + file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None) + if not file_record: + continue # Ignore files that are not accessible + rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True) + rec['id'] = d['url'][-36:] + rec['eid'] = (d['session'] or '')[-36:] + data_url = urllib.parse.urlsplit(file_record['data_url'], allow_fragments=False) + file_path = get_alf_path(data_url.path.strip('/')) + file_path = remove_uuid_string(file_path).as_posix() + session_path = get_session_path(file_path) or '' + if session_path: + session_path = session_path.as_posix() + rec['rel_path'] = file_path[len(session_path):].strip('/') + rec['default_revision'] = d['default_dataset'] + rec['qc'] = d.get('qc') + for field in additional or []: + rec[field] = d.get(field) + records.append(rec) + + index = ['eid', 'id'] + if not records: + keys = (*index, 'file_size', 'hash', 'session_path', 'rel_path', 'default_revision', 'qc') + return pd.DataFrame(columns=keys).set_index(index) + return pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE}) diff --git a/one/tests/test_converters.py b/one/tests/test_converters.py index 79aa0b78..5f639b46 100644 --- a/one/tests/test_converters.py +++ b/one/tests/test_converters.py @@ -10,6 +10,7 @@ from one.api import ONE from one import converters from one.alf.files import add_uuid_string +from one.alf.cache import EMPTY_DATASETS_FRAME from . 
import util, OFFLINE_ONLY, TEST_DB_2 @@ -344,6 +345,55 @@ def test_eid2pid(self): for d in det: self.assertTrue(set(d.keys()) >= expected_keys) + def test_ses2records(self): + """Test one.converters.ses2records function.""" + ses = self.one.alyx.rest('sessions', 'read', id=self.eid) + session, datasets = converters.ses2records(ses) + + # Verify returned tables are compatible with cache tables + self.assertIsInstance(session, pd.Series) + self.assertIsInstance(datasets, pd.DataFrame) + self.assertEqual(session.name, self.eid) + self.assertCountEqual(session.keys(), self.one._cache['sessions'].columns) + self.assertEqual(len(datasets), len(ses['data_dataset_session_related'])) + expected = list(EMPTY_DATASETS_FRAME.columns) + ['default_revision'] + self.assertCountEqual(expected, datasets.columns) + self.assertEqual(tuple(datasets.index.names), ('eid', 'id')) + self.assertIsInstance(datasets.qc.dtype, pd.CategoricalDtype) + + # Check behaviour when no datasets present + ses['data_dataset_session_related'] = [] + _, datasets = converters.ses2records(ses) + self.assertTrue(datasets.empty) + + def test_datasets2records(self): + """Test one.converters.datasets2records function.""" + dsets = self.one.alyx.rest('datasets', 'list', session=self.eid) + datasets = converters.datasets2records(dsets) + + # Verify returned tables are compatible with cache tables + self.assertIsInstance(datasets, pd.DataFrame) + self.assertTrue(len(datasets) >= len(dsets)) + expected = list(EMPTY_DATASETS_FRAME.columns) + ['default_revision'] + self.assertCountEqual(expected, datasets.columns) + self.assertEqual(tuple(datasets.index.names), ('eid', 'id')) + self.assertIsInstance(datasets.qc.dtype, pd.CategoricalDtype) + + # Test extracts additional fields + fields = ('url', 'auto_datetime') + datasets = converters.datasets2records(dsets, additional=fields) + self.assertTrue(set(datasets.columns) >= set(fields)) + self.assertTrue(all(datasets['url'].str.startswith('http'))) + + # Test single input + dataset = converters.datasets2records(dsets[0]) + self.assertTrue(len(dataset) == 1) + # Test records when data missing + for fr in dsets[0]['file_records']: + fr['exists'] = False + empty = converters.datasets2records(dsets[0]) + self.assertTrue(isinstance(empty, pd.DataFrame) and empty.empty) + class TestAlyx2Path(unittest.TestCase): dset = { diff --git a/one/tests/test_one.py b/one/tests/test_one.py index 26b3543b..fbf4e750 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -47,11 +47,12 @@ from one import __version__ from one.api import ONE, One, OneAlyx from one.util import ( - ses2records, validate_date_range, index_last_before, filter_datasets, _collection_spec, - filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records, ensure_list + validate_date_range, index_last_before, filter_datasets, _collection_spec, + filter_revision_last_before, parse_id, autocomplete, LazyId, ensure_list ) import one.params import one.alf.exceptions as alferr +from one.converters import datasets2records from one.alf import spec from one.alf.files import get_alf_path from . 
import util @@ -1105,57 +1106,6 @@ def test_dataset2type(self): with self.assertRaises(ValueError): self.one.dataset2type(bad_id) - def test_ses2records(self): - """Test one.util.ses2records""" - eid = '8dd0fcb0-1151-4c97-ae35-2e2421695ad7' - ses = self.one.alyx.rest('sessions', 'read', id=eid) - session, datasets = ses2records(ses) - - # Verify returned tables are compatible with cache tables - self.assertIsInstance(session, pd.Series) - self.assertIsInstance(datasets, pd.DataFrame) - self.assertEqual(session.name, eid) - self.assertCountEqual(session.keys(), self.one._cache['sessions'].columns) - self.assertEqual(len(datasets), len(ses['data_dataset_session_related'])) - expected = [x for x in self.one._cache['datasets'].columns] + ['default_revision'] - self.assertCountEqual(expected, datasets.columns) - self.assertEqual(tuple(datasets.index.names), ('eid', 'id')) - self.assertTrue(datasets.default_revision.all()) - self.assertIsInstance(datasets.qc.dtype, pd.CategoricalDtype) - - # Check behaviour when no datasets present - ses['data_dataset_session_related'] = [] - _, datasets = ses2records(ses) - self.assertTrue(datasets.empty) - - def test_datasets2records(self): - """Test one.util.datasets2records""" - eid = '8dd0fcb0-1151-4c97-ae35-2e2421695ad7' - dsets = self.one.alyx.rest('datasets', 'list', session=eid) - datasets = datasets2records(dsets) - - # Verify returned tables are compatible with cache tables - self.assertIsInstance(datasets, pd.DataFrame) - self.assertTrue(len(datasets) >= len(dsets)) - expected = self.one._cache['datasets'].columns - self.assertCountEqual(expected, (x for x in datasets.columns if x != 'default_revision')) - self.assertEqual(tuple(datasets.index.names), ('eid', 'id')) - self.assertIsInstance(datasets.qc.dtype, pd.CategoricalDtype) - - # Test extracts additional fields - fields = ('url', 'auto_datetime') - datasets = datasets2records(dsets, additional=fields) - self.assertTrue(set(datasets.columns) >= set(fields)) - self.assertTrue(all(datasets['url'].str.startswith('http'))) - - # Test single input - dataset = datasets2records(dsets[0]) - self.assertTrue(len(dataset) == 1) - # Test records when data missing - dsets[0]['file_records'][0]['exists'] = False - empty = datasets2records(dsets[0]) - self.assertTrue(isinstance(empty, pd.DataFrame) and len(empty) == 0) - def test_pid2eid(self): """Test OneAlyx.pid2eid""" pid = 'b529f2d8-cdae-4d59-aba2-cbd1b5572e36' @@ -1464,7 +1414,7 @@ def test_list_datasets(self): self.assertEqual(len(dsets), 0) # Test empty datasets - with mock.patch('one.util.ses2records', return_value=(pd.DataFrame(), pd.DataFrame())): + with mock.patch('one.api.ses2records', return_value=(pd.DataFrame(), pd.DataFrame())): dsets = self.one.list_datasets(self.eid, details=True, query_type='remote') self.assertIsInstance(dsets, pd.DataFrame) self.assertEqual(len(dsets), 0) diff --git a/one/tests/util.py b/one/tests/util.py index 98286c8d..bb176741 100644 --- a/one/tests/util.py +++ b/one/tests/util.py @@ -10,7 +10,7 @@ from iblutil.io.params import set_hidden import one.params -from one.util import QC_TYPE +from one.alf.cache import QC_TYPE from one.converters import session_record2path diff --git a/one/util.py b/one/util.py index 056da194..5c27a013 100644 --- a/one/util.py +++ b/one/util.py @@ -1,128 +1,22 @@ """Decorators and small standalone functions for api module.""" import re import logging -import urllib.parse import fnmatch import warnings from functools import wraps, partial -from typing import Sequence, Union, Iterable, Optional, 
List +from typing import Iterable, Optional, List from collections.abc import Mapping -from datetime import datetime import pandas as pd -from iblutil.io import parquet from iblutil.util import ensure_list as _ensure_list import numpy as np -from packaging import version import one.alf.exceptions as alferr -from one.alf.files import rel_path_parts, get_session_path, get_alf_path, remove_uuid_string +from one.alf.files import rel_path_parts from one.alf.spec import QC, FILE_SPEC, regex as alf_regex logger = logging.getLogger(__name__) -QC_TYPE = pd.CategoricalDtype(categories=[e.name for e in sorted(QC)], ordered=True) -"""pandas.api.types.CategoricalDtype: The cache table QC column data type.""" - - -def Listable(t): - """Return a typing.Union if the input and sequence of input.""" - return Union[t, Sequence[t]] - - -def ses2records(ses: dict): - """Extract session cache record and datasets cache from a remote session data record. - - Parameters - ---------- - ses : dict - Session dictionary from Alyx REST endpoint. - - Returns - ------- - pd.Series - Session record. - pd.DataFrame - Datasets frame. - """ - # Extract session record - eid = ses['url'][-36:] - session_keys = ('subject', 'start_time', 'lab', 'number', 'task_protocol', 'projects') - session_data = {k: v for k, v in ses.items() if k in session_keys} - session = ( - pd.Series(data=session_data, name=eid).rename({'start_time': 'date'}) - ) - session['projects'] = ','.join(session.pop('projects')) - session['date'] = datetime.fromisoformat(session['date']).date() - - # Extract datasets table - def _to_record(d): - rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True, id=d['id']) - rec['eid'] = session.name - file_path = urllib.parse.urlsplit(d['data_url'], allow_fragments=False).path.strip('/') - file_path = get_alf_path(remove_uuid_string(file_path)) - session_path = get_session_path(file_path).as_posix() - rec['rel_path'] = file_path[len(session_path):].strip('/') - rec['default_revision'] = d['default_revision'] == 'True' - rec['qc'] = d.get('qc', 'NOT_SET') - return rec - - if not ses.get('data_dataset_session_related'): - return session, pd.DataFrame() - records = map(_to_record, ses['data_dataset_session_related']) - index = ['eid', 'id'] - datasets = pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE}) - return session, datasets - - -def datasets2records(datasets, additional=None) -> pd.DataFrame: - """Extract datasets DataFrame from one or more Alyx dataset records. - - Parameters - ---------- - datasets : dict, list - One or more records from the Alyx 'datasets' endpoint. - additional : list of str - A set of optional fields to extract from dataset records. - - Returns - ------- - pd.DataFrame - Datasets frame. 
- - Examples - -------- - >>> datasets = ONE().alyx.rest('datasets', 'list', subject='foobar') - >>> df = datasets2records(datasets) - """ - records = [] - - for d in _ensure_list(datasets): - file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None) - if not file_record: - continue # Ignore files that are not accessible - rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True) - rec['id'] = d['url'][-36:] - rec['eid'] = (d['session'] or '')[-36:] - data_url = urllib.parse.urlsplit(file_record['data_url'], allow_fragments=False) - file_path = get_alf_path(data_url.path.strip('/')) - file_path = remove_uuid_string(file_path).as_posix() - session_path = get_session_path(file_path) or '' - if session_path: - session_path = session_path.as_posix() - rec['rel_path'] = file_path[len(session_path):].strip('/') - rec['default_revision'] = d['default_dataset'] - rec['qc'] = d.get('qc') - for field in additional or []: - rec[field] = d.get(field) - records.append(rec) - - index = ['eid', 'id'] - if not records: - keys = (*index, 'file_size', 'hash', 'session_path', 'rel_path', 'default_revision', 'qc') - return pd.DataFrame(columns=keys).set_index(index) - return pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE}) - def parse_id(method): """ @@ -634,53 +528,3 @@ def ses2eid(ses): return [LazyId.ses2eid(x) for x in ses] else: return ses.get('id', None) or ses['url'].split('/').pop() - - -def cache_int2str(table: pd.DataFrame) -> pd.DataFrame: - """Convert int ids to str ids for cache table. - - Parameters - ---------- - table : pd.DataFrame - A cache table (from One._cache). - - """ - # Convert integer uuids to str uuids - if table.index.nlevels < 2 or not any(x.endswith('_0') for x in table.index.names): - return table - table = table.reset_index() - int_cols = table.filter(regex=r'_\d{1}$').columns.sort_values() - assert not len(int_cols) % 2, 'expected even number of columns ending in _0 or _1' - names = sorted(set(c.rsplit('_', 1)[0] for c in int_cols.values)) - for i, name in zip(range(0, len(int_cols), 2), names): - table[name] = parquet.np2str(table[int_cols[i:i + 2]]) - table = table.drop(int_cols, axis=1).set_index(names) - return table - - -def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame: - """Reformat older cache tables to comply with this version of ONE. - - Currently this function will 1. convert integer UUIDs to string UUIDs; 2. rename the 'project' - column to 'projects'. - - Parameters - ---------- - table : pd.DataFrame - A cache table (from One._cache). - min_api_version : str - The minimum API version supported by this cache table. - name : {'dataset', 'session'} str - The name of the table. 
- """ - min_version = version.parse(min_api_version or '0.0.0') - table = cache_int2str(table) - # Rename project column - if min_version < version.Version('1.13.0') and 'project' in table.columns: - table.rename(columns={'project': 'projects'}, inplace=True) - if name == 'datasets' and min_version < version.Version('2.7.0') and 'qc' not in table.columns: - qc = pd.Categorical.from_codes(np.zeros(len(table.index), dtype=int), dtype=QC_TYPE) - table = table.assign(qc=qc) - if name == 'datasets' and 'session_path' in table.columns: - table = table.drop('session_path', axis=1) - return table diff --git a/requirements.txt b/requirements.txt index 1c3f9b7f..ffb56f1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ numpy>=1.18 pandas>=1.5.0 tqdm>=4.32.1 requests>=2.22.0 -iblutil>=1.13.0 +iblutil>=1.14.0 packaging boto3 pyyaml