Skip to content

Commit

Permalink
Merge pull request #1112 from bpinsard/pr/684
Browse files Browse the repository at this point in the history
FIX: Make lists and dicts hashable
  • Loading branch information
effigies authored Dec 13, 2024
2 parents 2c0469b + 234f624 commit c66a4d9
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"num2words >=0.5.5",
"click >=8.0",
"universal_pathlib >=0.2.2",
"frozendict >=2",
]
dynamic = ["version"]

Expand Down
29 changes: 19 additions & 10 deletions src/bids/layout/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sqlalchemy.sql.expression import cast
from bids_validator import BIDSValidator

from ..utils import listify, natural_sort
from ..utils import listify, natural_sort, hashablefy
from ..external import inflect
from ..exceptions import (
BIDSDerivativesValidationError,
Expand Down Expand Up @@ -645,6 +645,13 @@ def get(self, return_type='object', target=None, scope='all',
A list of BIDSFiles (default) or strings (see return_type).
"""

if (
not return_type.startswith(("obj", "file"))
and return_type not in ("id", "dir")
):
raise ValueError(f"Invalid return_type <{return_type}> specified (must be one "
"of 'object', 'file', 'filename', 'id', or 'dir').")

if absolute_paths is False:
absolute_path_deprecation_warning()

Expand Down Expand Up @@ -691,6 +698,9 @@ def get(self, return_type='object', target=None, scope='all',
message = "Valid targets are: {}".format(potential)
raise TargetError(("Unknown target '{}'. " + message)
.format(target))
elif target is None and return_type in ['id', 'dir']:
raise TargetError('If return_type is "id" or "dir", a valid '
'target entity must also be specified.')

results = []
for l in layouts:
Expand Down Expand Up @@ -718,18 +728,22 @@ def get(self, return_type='object', target=None, scope='all',

if return_type.startswith('file'):
results = natural_sort([f.path for f in results])

elif return_type in ['id', 'dir']:
if target is None:
raise TargetError('If return_type is "id" or "dir", a valid '
'target entity must also be specified.')

metadata = target not in self.get_entities(metadata=False)

if return_type == 'id':
ent_iter = (
hashablefy(res.get_entities(metadata=metadata))
for res in results if target in res.entities
)
results = list(dict.fromkeys(
res.entities[target] for res in results
if target in res.entities and isinstance(res.entities[target], Hashable)
ents[target] for ents in ent_iter if target in ents
))

results = natural_sort(list(set(results)))
elif return_type == 'dir':
template = entities[target].directory
if template is None:
Expand All @@ -752,12 +766,7 @@ def get(self, return_type='object', target=None, scope='all',
for f in results
if re.search(template, f._dirname.as_posix())
]

results = natural_sort(list(set(matches)))

else:
raise ValueError("Invalid return_type specified (must be one "
"of 'tuple', 'filename', 'id', or 'dir'.")
else:
results = natural_sort(results, 'path')

Expand Down
21 changes: 21 additions & 0 deletions src/bids/layout/tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,27 @@ def test_get_tr(layout_7t_trt):
assert sum([t in tr for t in [3.0, 4.0]]) == 2


def test_get_nonhashable_metadata(layout_ds117):
"""Test nonhashable metadata values (#683)."""
assert layout_ds117.get_IntendedFor(subject=['01'])[0] == (
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-01_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-02_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-03_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-04_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-05_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-06_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-07_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-08_bold.nii.gz",
"ses-mri/func/sub-01_ses-mri_task-facerecognition_run-09_bold.nii.gz",
)

landmarks = layout_ds117.get_AnatomicalLandmarkCoordinates(subject=['01'])[0]
assert landmarks["Nasion"] == (43, 111, 95)
assert landmarks["LPA"] == (140, 74, 16)
assert landmarks["RPA"] == (143, 74, 173)



def test_to_df(layout_ds117):
# Only filename entities
df = layout_ds117.to_df()
Expand Down
21 changes: 21 additions & 0 deletions src/bids/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,36 @@

import re
import os
from pathlib import Path
from frozendict import frozendict as _frozendict
from upath import UPath as Path


# Monkeypatch to print out frozendicts *as if* they were dictionaries.
class frozendict(_frozendict):
"""A hashable dictionary type."""

def __repr__(self):
"""Override frozendict representation."""
return repr({k: v for k, v in self.items()})


def listify(obj):
''' Wraps all non-list or tuple objects in a list; provides a simple way
to accept flexible arguments. '''
return obj if isinstance(obj, (list, tuple, type(None))) else [obj]


def hashablefy(obj):
''' Make dictionaries and lists hashable or raise. '''
if isinstance(obj, list):
return tuple([hashablefy(o) for o in obj])

if isinstance(obj, dict):
return frozendict({k: hashablefy(v) for k, v in obj.items()})
return obj


def matches_entities(obj, entities, strict=False):
''' Checks whether an object's entities match the input. '''
if strict and set(obj.entities.keys()) != set(entities.keys()):
Expand Down

0 comments on commit c66a4d9

Please sign in to comment.