Refactor function locations; empty datasets and sessions frames with correct dtypes
k1o0 committed Nov 3, 2024
1 parent f6b48e2 commit 6ee3dac
Showing 8 changed files with 270 additions and 261 deletions.
124 changes: 94 additions & 30 deletions one/alf/cache.py
@@ -11,7 +11,6 @@
>>> one = One(cache_dir=cache_dir)
"""


# -------------------------------------------------------------------------------------------------
# Imports
# -------------------------------------------------------------------------------------------------
@@ -24,42 +23,57 @@
import logging

import pandas as pd
import numpy as np
from packaging import version
from iblutil.io import parquet
from iblutil.io.hashfile import md5

from one.alf.spec import QC
from one.alf.io import iter_sessions, iter_datasets
from one.alf.path import session_path_parts, get_alf_path
from one.converters import session_record2path
from one.util import QC_TYPE, patch_cache

__all__ = [
'make_parquet_db', 'remove_missing_datasets', 'remove_cache_table_files',
'DATASETS_COLUMNS', 'SESSIONS_COLUMNS']
__all__ = ['make_parquet_db', 'patch_cache', 'remove_missing_datasets',
'remove_cache_table_files', 'EMPTY_DATASETS_FRAME', 'EMPTY_SESSIONS_FRAME', 'QC_TYPE']
_logger = logging.getLogger(__name__)

# -------------------------------------------------------------------------------------------------
# Global variables
# -------------------------------------------------------------------------------------------------

SESSIONS_COLUMNS = (
'id', # int64
'lab', # str
'subject', # str
'date', # datetime.date
'number', # int
'task_protocol', # str
'projects', # str
)

DATASETS_COLUMNS = (
'id', # int64
'eid', # int64
'rel_path', # relative to the session path, includes the filename
'file_size', # file size in bytes
'hash', # sha1/md5, computed in load function
'exists', # bool
'qc', # one.util.QC_TYPE
)
QC_TYPE = pd.CategoricalDtype(categories=[e.name for e in sorted(QC)], ordered=True)
"""pandas.api.types.CategoricalDtype: The cache table QC column data type."""

SESSIONS_COLUMNS = {
'id': object, # str
'lab': object, # str
'subject': object, # str
'date': object, # datetime.date
'number': np.uint16, # int
'task_protocol': object, # str
'projects': object # str
}
"""dict: A map of sessions table fields and their data types."""

DATASETS_COLUMNS = {
'eid': object, # str
'id': object, # str
'rel_path': object, # relative to the session path, includes the filename
'file_size': 'UInt64', # file size in bytes (nullable)
'hash': object, # sha1/md5, computed in load function
'exists': bool, # bool
'qc': QC_TYPE # one.alf.spec.QC enumeration
}
"""dict: A map of datasets table fields and their data types."""

EMPTY_DATASETS_FRAME = (pd.DataFrame(columns=DATASETS_COLUMNS)
.astype(DATASETS_COLUMNS)
.set_index(['eid', 'id']))
"""pandas.DataFrame: An empty datasets dataframe with correct columns and dtypes."""

EMPTY_SESSIONS_FRAME = (pd.DataFrame(columns=SESSIONS_COLUMNS)
.astype(SESSIONS_COLUMNS)
.set_index('id'))
"""pandas.DataFrame: An empty sessions dataframe with correct columns and dtypes."""


# -------------------------------------------------------------------------------------------------
@@ -103,8 +117,7 @@ def _rel_path_to_uuid(df, id_key='rel_path', base_id=None, keep_old=False):
toUUID = partial(uuid.uuid3, base_id) # MD5 hash from base uuid and rel session path string
if keep_old:
df[f'{id_key}_'] = df[id_key].copy()
df[id_key] = df[id_key].apply(lambda x: str(toUUID(x)))
assert len(df[id_key].unique()) == len(df[id_key]) # WARNING This fails :(
df.loc[:, id_key] = df.groupby(id_key)[id_key].transform(lambda x: str(toUUID(x.name)))
return df
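The change above drops the failing uniqueness assert: uuid3 is deterministic, so duplicate relative paths now simply share an ID, and the groupby/transform computes the hash once per unique value. A minimal illustration of the pattern (not part of the diff; the namespace and paths are invented):

```python
import uuid
from functools import partial

import pandas as pd

base_id = uuid.uuid1()  # stand-in namespace; the real code supplies its own base_id
to_uuid = partial(uuid.uuid3, base_id)  # MD5-based, deterministic for a given name

df = pd.DataFrame({'rel_path': ['alf/spikes.times.npy', 'alf/spikes.times.npy']})
ids = df.groupby('rel_path')['rel_path'].transform(lambda x: str(to_uuid(x.name)))
assert ids.iloc[0] == ids.iloc[1]  # duplicate paths map to the same UUID
```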


Expand Down Expand Up @@ -173,7 +186,7 @@ def _make_sessions_df(root_dir) -> pd.DataFrame:
ses_info = _get_session_info(rel_path)
assert set(ses_info.keys()) <= set(SESSIONS_COLUMNS)
rows.append(ses_info)
df = pd.DataFrame(rows, columns=SESSIONS_COLUMNS)
df = pd.DataFrame(rows, columns=SESSIONS_COLUMNS).astype(SESSIONS_COLUMNS)
return df


@@ -193,15 +206,15 @@ def _make_datasets_df(root_dir, hash_files=False) -> pd.DataFrame:
pandas.DataFrame
A pandas DataFrame of dataset info.
"""
df = pd.DataFrame([], columns=DATASETS_COLUMNS).astype({'qc': QC_TYPE})
df = EMPTY_DATASETS_FRAME.copy()
# Go through sessions and append datasets
for session_path in iter_sessions(root_dir):
rows = []
for rel_dset_path in iter_datasets(session_path):
file_info = _get_dataset_info(session_path, rel_dset_path, compute_hash=hash_files)
assert set(file_info.keys()) <= set(DATASETS_COLUMNS)
rows.append(file_info)
df = pd.concat((df, pd.DataFrame(rows, columns=DATASETS_COLUMNS)),
df = pd.concat((df, pd.DataFrame(rows, columns=DATASETS_COLUMNS).astype(DATASETS_COLUMNS)),
ignore_index=True, verify_integrity=True)
return df.astype({'qc': QC_TYPE})
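End to end, the public entry point still builds both tables, which now load back with the declared dtypes. A usage sketch (not part of the diff; the directory path is invented and make_parquet_db is assumed to keep its current signature):

```python
from one.api import One
from one.alf.cache import make_parquet_db

cache_dir = '/data/my_lab'  # hypothetical ALF root: <lab>/Subjects/<subject>/<date>/<number>
ses_file, dset_file = make_parquet_db(cache_dir)  # writes sessions.pqt / datasets.pqt
one = One(cache_dir=cache_dir)  # loads the tables written above
print(one._cache['datasets'].dtypes)  # should match DATASETS_COLUMNS, with qc categorical
```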

@@ -308,6 +321,7 @@ def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True,
tables[name].set_index(idx_columns, inplace=True)

to_delete = set()
from one.converters import session_record2path # imported here due to circular imports
gen_path = partial(session_record2path, root_dir=cache_dir)
# map of session path to eid
sessions = {gen_path(rec): eid for eid, rec in tables['sessions'].iterrows()}
@@ -367,3 +381,53 @@ def remove_cache_table_files(folder, tables=('sessions', 'datasets')):
else:
_logger.warning('%s not found', file)
return removed


def _cache_int2str(table: pd.DataFrame) -> pd.DataFrame:
"""Convert int ids to str ids for cache table.
Parameters
----------
table : pd.DataFrame
A cache table (from One._cache).
"""
# Convert integer uuids to str uuids
if table.index.nlevels < 2 or not any(x.endswith('_0') for x in table.index.names):
return table
table = table.reset_index()
int_cols = table.filter(regex=r'_\d{1}$').columns.sort_values()
assert not len(int_cols) % 2, 'expected even number of columns ending in _0 or _1'
names = sorted(set(c.rsplit('_', 1)[0] for c in int_cols.values))
for i, name in zip(range(0, len(int_cols), 2), names):
table[name] = parquet.np2str(table[int_cols[i:i + 2]])
table = table.drop(int_cols, axis=1).set_index(names)
return table
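For reference, this conversion targets the legacy layout in which each UUID was stored as a pair of int64 columns (e.g. 'eid_0'/'eid_1'). A round-trip sketch (not part of the diff; it assumes iblutil's parquet.str2np/np2str helpers and uses an invented UUID):

```python
import pandas as pd
from iblutil.io import parquet

eids = ['01234567-89ab-cdef-0123-456789abcdef']
ints = parquet.str2np(eids)  # (N, 2) int64 representation used by old caches
old = pd.DataFrame(ints, columns=['eid_0', 'eid_1'])
print(parquet.np2str(old[['eid_0', 'eid_1']]))  # back to the string UUID
```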


def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame:
"""Reformat older cache tables to comply with this version of ONE.
Currently this function will 1. convert integer UUIDs to string UUIDs; 2. rename the 'project'
column to 'projects'; 3. add QC column; 4. drop session_path column.
Parameters
----------
table : pd.DataFrame
A cache table (from One._cache).
min_api_version : str
The minimum API version supported by this cache table.
name : {'datasets', 'sessions'} str
The name of the table.
"""
min_version = version.parse(min_api_version or '0.0.0')
table = _cache_int2str(table)
# Rename project column
if min_version < version.Version('1.13.0') and 'project' in table.columns:
table.rename(columns={'project': 'projects'}, inplace=True)
if name == 'datasets' and min_version < version.Version('2.7.0') and 'qc' not in table.columns:
qc = pd.Categorical.from_codes(np.zeros(len(table.index), dtype=int), dtype=QC_TYPE)
table = table.assign(qc=qc)
if name == 'datasets' and 'session_path' in table.columns:
table = table.drop('session_path', axis=1)
return table
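A typical call site mirrors what load_cache does with the parquet metadata. An illustrative sketch (not part of the diff; it assumes a datasets.pqt file on disk and that parquet.load returns the table with its embedded metadata):

```python
from iblutil.io import parquet
from one.alf.cache import patch_cache

table, meta = parquet.load('datasets.pqt')  # table plus embedded metadata
table = patch_cache(table, meta.get('min_api_version'), name='datasets')
assert 'qc' in table.columns and 'session_path' not in table.columns
```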
33 changes: 19 additions & 14 deletions one/api.py
@@ -21,18 +21,20 @@
import packaging.version

from iblutil.io import parquet, hashfile
from iblutil.util import Bunch, flatten, ensure_list
from iblutil.util import Bunch, flatten, ensure_list, Listable

import one.params
import one.webclient as wc
import one.alf.io as alfio
import one.alf.path as alfiles
import one.alf.exceptions as alferr
from .alf.cache import (
make_parquet_db, remove_cache_table_files, DATASETS_COLUMNS, SESSIONS_COLUMNS)
make_parquet_db, patch_cache, remove_cache_table_files,
EMPTY_DATASETS_FRAME, EMPTY_SESSIONS_FRAME
)
from .alf.spec import is_uuid_string, QC, to_alf
from . import __version__
from one.converters import ConversionMixin, session_record2path
from one.converters import ConversionMixin, session_record2path, ses2records, datasets2records
from one import util

_logger = logging.getLogger(__name__)
@@ -101,8 +103,8 @@ def search_terms(self, query_type=None) -> tuple:
def _reset_cache(self):
"""Replace the cache object with a Bunch that contains the right fields."""
self._cache = Bunch({
'datasets': pd.DataFrame(columns=DATASETS_COLUMNS).set_index(['eid', 'id']),
'sessions': pd.DataFrame(columns=SESSIONS_COLUMNS).set_index('id'),
'datasets': EMPTY_DATASETS_FRAME.copy(),
'sessions': EMPTY_SESSIONS_FRAME.copy(),
'_meta': {
'expired': False,
'created_time': None,
@@ -160,7 +162,7 @@ def load_cache(self, tables_dir=None, **kwargs):
cache.set_index(idx_columns, inplace=True)

# Patch older tables
cache = util.patch_cache(cache, meta['raw'][table].get('min_api_version'), table)
cache = patch_cache(cache, meta['raw'][table].get('min_api_version'), table)

# Check sorted
# Sorting makes MultiIndex indexing O(N) -> O(1)
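The sort check matters because partial label lookups on an unsorted MultiIndex are slow and emit PerformanceWarnings; once sorted they are fast. A small illustration (not part of the diff; names invented):

```python
import pandas as pd

idx = pd.MultiIndex.from_tuples(
    [('eid1', 'did0'), ('eid0', 'did1'), ('eid0', 'did0')], names=('eid', 'id'))
df = pd.DataFrame({'exists': [True, False, True]}, index=idx)
print(df.index.is_monotonic_increasing)  # False: unsorted
df = df.sort_index()
print(df.loc['eid0'])  # fast, warning-free partial indexing on the sorted index
```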
@@ -173,6 +175,9 @@
# No tables present
meta['expired'] = True
meta['raw'] = {}
self._cache.update({
'datasets': EMPTY_DATASETS_FRAME.copy(),
'sessions': EMPTY_SESSIONS_FRAME.copy()})
if self.offline: # In online mode, the cache tables should be downloaded later
warnings.warn(f'No cache tables found in {self._tables_dir}')
created = [datetime.fromisoformat(x['date_created'])
@@ -286,7 +291,7 @@ def _update_cache_from_records(self, strict=False, **kwargs):
Example
-------
>>> session, datasets = util.ses2records(self.get_details(eid, full=True))
>>> session, datasets = ses2records(self.get_details(eid, full=True))
... self._update_cache_from_records(sessions=session, datasets=datasets)
Raises
@@ -599,7 +604,7 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha
datasets.index.set_names(idx_names, inplace=True)
elif not isinstance(datasets, pd.DataFrame):
# Cast set of dicts (i.e. from REST datasets endpoint)
datasets = util.datasets2records(list(datasets))
datasets = datasets2records(list(datasets))
else:
datasets = datasets.copy()
indices_to_download = [] # indices of datasets that need (re)downloading
@@ -1868,7 +1873,7 @@ def list_datasets(
eid = self.to_eid(eid) # Ensure we have a UUID str list
if not eid:
return self._cache['datasets'].iloc[0:0] if details else [] # Return empty
session, datasets = util.ses2records(self.alyx.rest('sessions', 'read', id=eid))
session, datasets = ses2records(self.alyx.rest('sessions', 'read', id=eid))
# Add to cache tables
self._update_cache_from_records(sessions=session, datasets=datasets.copy())
if datasets is None or datasets.empty:
@@ -1913,7 +1918,7 @@ def list_aggregates(self, relation: str, identifier: str = None,
"""
query = 'session__isnull,True' # ',data_repository_name__endswith,aggregates'
all_aggregates = self.alyx.rest('datasets', 'list', django=query)
records = (util.datasets2records(all_aggregates)
records = (datasets2records(all_aggregates)
.reset_index(level=0)
.drop('eid', axis=1))
# Since rel_path for public FI file records starts with 'public/aggregates' instead of just
@@ -2328,7 +2333,7 @@ def _update_sessions_table(self, session_records):
datetime.datetime:
A timestamp of when the cache was updated.
"""
df = pd.DataFrame(next(zip(*map(util.ses2records, session_records))))
df = pd.DataFrame(next(zip(*map(ses2records, session_records))))
return self._update_cache_from_records(sessions=df)

def _download_datasets(self, dsets, **kwargs) -> List[Path]:
@@ -2661,7 +2666,7 @@ def setup(base_url=None, **kwargs):

@util.refresh
@util.parse_id
def eid2path(self, eid, query_type=None) -> util.Listable(Path):
def eid2path(self, eid, query_type=None) -> Listable(Path):
"""
From an experiment ID gets the local session path
@@ -2700,7 +2705,7 @@ def eid2path(self, eid, query_type=None) -> util.Listable(Path):
str(ses[0]['number']).zfill(3))

@util.refresh
def path2eid(self, path_obj: Union[str, Path], query_type=None) -> util.Listable(Path):
def path2eid(self, path_obj: Union[str, Path], query_type=None) -> Listable(Path):
"""
From a local path, gets the experiment ID
@@ -2804,7 +2809,7 @@ def type2datasets(self, eid, dataset_type, details=False):
restriction = f'session__id,{eid},dataset_type__name__in,{dataset_type}'
else:
raise TypeError('dataset_type must be a str or str list')
datasets = util.datasets2records(self.alyx.rest('datasets', 'list', django=restriction))
datasets = datasets2records(self.alyx.rest('datasets', 'list', django=restriction))
return datasets if details else datasets['rel_path'].sort_values().values

def dataset2type(self, dset) -> str: