diff --git a/pyproject.toml b/pyproject.toml index 841bd389..97d239eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "num2words >=0.5.5", "click >=8.0", "universal_pathlib >=0.2.2", + "frozendict >=2", ] dynamic = ["version"] diff --git a/src/bids/layout/layout.py b/src/bids/layout/layout.py index f37675ab..ced66fc9 100644 --- a/src/bids/layout/layout.py +++ b/src/bids/layout/layout.py @@ -16,7 +16,7 @@ from sqlalchemy.sql.expression import cast from bids_validator import BIDSValidator -from ..utils import listify, natural_sort +from ..utils import listify, natural_sort, hashablefy from ..external import inflect from ..exceptions import ( BIDSDerivativesValidationError, @@ -645,6 +645,13 @@ def get(self, return_type='object', target=None, scope='all', A list of BIDSFiles (default) or strings (see return_type). """ + if ( + not return_type.startswith(("obj", "file")) + and return_type not in ("id", "dir") + ): + raise ValueError(f"Invalid return_type <{return_type}> specified (must be one " + "of 'object', 'file', 'filename', 'id', or 'dir').") + if absolute_paths is False: absolute_path_deprecation_warning() @@ -691,6 +698,9 @@ def get(self, return_type='object', target=None, scope='all', message = "Valid targets are: {}".format(potential) raise TargetError(("Unknown target '{}'. " + message) .format(target)) + elif target is None and return_type in ['id', 'dir']: + raise TargetError('If return_type is "id" or "dir", a valid ' + 'target entity must also be specified.') results = [] for l in layouts: @@ -718,18 +728,22 @@ def get(self, return_type='object', target=None, scope='all', if return_type.startswith('file'): results = natural_sort([f.path for f in results]) - elif return_type in ['id', 'dir']: if target is None: raise TargetError('If return_type is "id" or "dir", a valid ' 'target entity must also be specified.') + metadata = target not in self.get_entities(metadata=False) + if return_type == 'id': + ent_iter = ( + hashablefy(res.get_entities(metadata=metadata)) + for res in results if target in res.entities + ) results = list(dict.fromkeys( - res.entities[target] for res in results - if target in res.entities and isinstance(res.entities[target], Hashable) + ents[target] for ents in ent_iter if target in ents )) - + results = natural_sort(list(set(results))) elif return_type == 'dir': template = entities[target].directory if template is None: @@ -752,12 +766,7 @@ def get(self, return_type='object', target=None, scope='all', for f in results if re.search(template, f._dirname.as_posix()) ] - results = natural_sort(list(set(matches))) - - else: - raise ValueError("Invalid return_type specified (must be one " - "of 'tuple', 'filename', 'id', or 'dir'.") else: results = natural_sort(results, 'path') diff --git a/src/bids/layout/tests/test_layout.py b/src/bids/layout/tests/test_layout.py index af933d5b..c0c7089a 100644 --- a/src/bids/layout/tests/test_layout.py +++ b/src/bids/layout/tests/test_layout.py @@ -776,6 +776,27 @@ def test_get_tr(layout_7t_trt): assert sum([t in tr for t in [3.0, 4.0]]) == 2 +def test_get_nonhashable_metadata(layout_ds117): + """Test nonhashable metadata values (#683).""" + assert layout_ds117.get_IntendedFor(subject=['01'])[0] == ( + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-01_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-02_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-03_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-04_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-05_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-06_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-07_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-08_bold.nii.gz", + "ses-mri/func/sub-01_ses-mri_task-facerecognition_run-09_bold.nii.gz", + ) + + landmarks = layout_ds117.get_AnatomicalLandmarkCoordinates(subject=['01'])[0] + assert landmarks["Nasion"] == (43, 111, 95) + assert landmarks["LPA"] == (140, 74, 16) + assert landmarks["RPA"] == (143, 74, 173) + + + def test_to_df(layout_ds117): # Only filename entities df = layout_ds117.to_df() diff --git a/src/bids/utils.py b/src/bids/utils.py index e2a57f9b..f6b9542d 100644 --- a/src/bids/utils.py +++ b/src/bids/utils.py @@ -2,15 +2,36 @@ import re import os +from pathlib import Path +from frozendict import frozendict as _frozendict from upath import UPath as Path +# Monkeypatch to print out frozendicts *as if* they were dictionaries. +class frozendict(_frozendict): + """A hashable dictionary type.""" + + def __repr__(self): + """Override frozendict representation.""" + return repr({k: v for k, v in self.items()}) + + def listify(obj): ''' Wraps all non-list or tuple objects in a list; provides a simple way to accept flexible arguments. ''' return obj if isinstance(obj, (list, tuple, type(None))) else [obj] +def hashablefy(obj): + ''' Make dictionaries and lists hashable or raise. ''' + if isinstance(obj, list): + return tuple([hashablefy(o) for o in obj]) + + if isinstance(obj, dict): + return frozendict({k: hashablefy(v) for k, v in obj.items()}) + return obj + + def matches_entities(obj, entities, strict=False): ''' Checks whether an object's entities match the input. ''' if strict and set(obj.entities.keys()) != set(entities.keys()):