From 18946e454d1df3b5e0ef6d404a40c66c530a24cb Mon Sep 17 00:00:00 2001 From: Peter Van Dyken Date: Tue, 17 Oct 2023 18:11:44 -0400 Subject: [PATCH] Overhaul global filtering of participants --participant-filter and --exclude-participant filter have had a couple of bugs over the past few versions: 1. Both terms would apply to every single component, even components without the subject entity. These components would have all of their entries filtered out 2. Using --exclude-participant-filter would turn on regex matching mode for every single filter. This changed the meaning of the filters (e.g. allowing partial matches), leading to workflow disruption The overhaul fixes both bugs, while improving the organization of the input generation code. Additionally, the documentation states the magic filter `regex_search: True` can be used to enable regex searching for a block of filters. This hasn't worked for the past few versions. Rather than fix it, the behaviour has been silently disabled in preparation for an overhaul of the regex filtering api Resolves #303 Resolves #216 --- snakebids/core/datasets.py | 2 + snakebids/core/filtering.py | 2 +- snakebids/core/input_generation.py | 321 ++++++++++++++++-------- snakebids/tests/helpers.py | 18 +- snakebids/tests/test_generate_inputs.py | 316 +++++++++++++++++++---- snakebids/types.py | 2 +- 6 files changed, 496 insertions(+), 165 deletions(-) diff --git a/snakebids/core/datasets.py b/snakebids/core/datasets.py index 1c910ed2..8b3eba7c 100644 --- a/snakebids/core/datasets.py +++ b/snakebids/core/datasets.py @@ -446,6 +446,8 @@ def filter( if not isinstance(regex_search, bool): msg = "regex_search must be a boolean" raise TypeError(msg) + if not filters: + return self return attr.evolve( self, zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search), diff --git a/snakebids/core/filtering.py b/snakebids/core/filtering.py index 87a6b570..b67be33d 100644 --- a/snakebids/core/filtering.py +++ 
b/snakebids/core/filtering.py @@ -25,7 +25,7 @@ def filter_list( def filter_list( zip_list: ZipListLike, filters: Mapping[str, Iterable[str] | str], - return_indices_only: Literal[True] = ..., + return_indices_only: Literal[True], regex_search: bool = ..., ) -> list[int]: ... diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index fef7512b..0710a9ec 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -1,6 +1,7 @@ """Utilities for converting Snakemake apps to BIDS apps.""" from __future__ import annotations +import functools as ft import json import logging import os @@ -24,12 +25,15 @@ PybidsError, RunError, ) -from snakebids.types import InputsConfig, ZipList +from snakebids.types import InputConfig, InputsConfig, ZipList from snakebids.utils.snakemake_io import glob_wildcards from snakebids.utils.utils import BidsEntity, BidsParseError, get_first_dir _logger = logging.getLogger(__name__) +FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" +CompiledFilter: TypeAlias = "Mapping[str, str | Query | Sequence[str | Query]]" + @overload def generate_inputs( @@ -250,9 +254,8 @@ def generate_inputs( }) """ - subject_filter, regex_search = _generate_filters( - participant_label, exclude_participant_label - ) + postfilters = _Postfilter() + postfilters.add_filter("subject", participant_label, exclude_participant_label) if pybids_database_dir: _logger.warning( @@ -281,13 +284,11 @@ def generate_inputs( else None ) - filters = {"subject": subject_filter} if subject_filter else {} bids_inputs = _get_lists_from_bids( bids_layout=layout, pybids_inputs=pybids_inputs, limit_to=limit_to, - regex_search=regex_search, - **(filters), + postfilters=postfilters, ) if use_bids_inputs is True: @@ -403,64 +404,69 @@ def write_derivative_json(snakemake: Snakemake, **kwargs: dict[str, Any]) -> Non json.dump(sidecar, outfile, indent=4) -def _generate_filters( - include: Iterable[str] | str | None 
= None, - exclude: Iterable[str] | str | None = None, -) -> tuple[list[str], bool]: - """Generate Pybids filter based on inclusion or exclusion criteria - - Converts either a list of values to include or exclude in a list of Pybids - compatible filters. Unlike inclusion values, exclusion requires regex filtering. The - necessity for regex will be indicated by the boolean value of the second returned - item: True if regex is needed, False otherwise. Raises an exception if both include - and exclude are stipulated - - Parameters - ---------- - include : list of str or str, optional - Values to include, values not found in this list will be excluded, by default - None - exclude : list of str or str, optional - Values to exclude, only values not found in this list will be included, by - default None +class _Postfilter: + """Filters to apply after indexing, typically derived from the CLI - Returns - ------- - list of str, bool - Two values: the first, a list of pybids compatible filters; the second, a - boolean indicating whether regex_search must be enabled in pybids - - Raises - ------ - ValueError Raised of both include and exclude values are stipulated. 
+ Currently used for supporting ``--[exclude-]participant-label`` """ - if include is not None and exclude is not None: - raise ValueError( - "Cannot define both participant_label and " - "exclude_participant_label at the same time" - ) - # add participant_label or exclude_participant_label to search terms (if - # defined) - # we make the item key in search_terms a list so we can have both - # include and exclude defined - if include is not None: - return [*itx.always_iterable(include)], False + def __init__(self): + self.inclusions: dict[str, Sequence[str] | str] = {} + self.exclusions: dict[str, Sequence[str] | str] = {} - if exclude is not None: + def add_filter( + self, + key: str, + inclusions: Iterable[str] | str | None, + exclusions: Iterable[str] | str | None, + ): + """Add entity filter based on inclusion or exclusion criteria + + Converts either a list of values to include or exclude in a list of Pybids + compatible filters. Unlike inclusion values, exclusion requires regex filtering. + The necessity for regex will be indicated by the boolean value of the second + returned item: True if regex is needed, False otherwise. Raises an exception if + both include and exclude are stipulated + + _Postfilter is modified in-place + + Parameters + ---------- + key + Name of entity to be filtered + inclusions + Values to include, values not found in this list will be excluded, by + default None + exclusions + Values to exclude, only values not found in this list will be included, by + default None + + Raises + ------ + ValueError Raised if both include and exclude values are stipulated. 
+ """ + if inclusions is not None and exclusions is not None: + raise ValueError( + "Cannot define both participant_label and " + "exclude_participant_label at the same time" + ) + if inclusions is not None: + self.inclusions[key] = list(itx.always_iterable(inclusions)) + if exclusions is not None: + self.exclusions[key] = self._format_exclusions(exclusions) + + def _format_exclusions(self, exclusions: Iterable[str] | str): # if multiple items to exclude, combine with with item1|item2|... exclude_string = "|".join( - re.escape(label) for label in itx.always_iterable(exclude) + re.escape(label) for label in itx.always_iterable(exclusions) ) # regex to exclude subjects - return [f"^((?!({exclude_string})$).*)$"], True - return [], False + return [f"^((?!({exclude_string})$).*)$"] def _parse_custom_path( input_path: Path | str, - regex_search: bool = False, - **filters: Sequence[str | bool] | str | bool, + filters: _UnifiedFilter, ) -> ZipList: """Glob wildcards from a custom path and apply filters @@ -495,17 +501,15 @@ def _parse_custom_path( return wildcards # Return the output values, running filtering on the zip_lists - if any( - isinstance(v, bool) for f in filters.values() for v in itx.always_iterable(f) - ): - raise TypeError( - "boolean filters are not currently supported in custom path filtering" - ) - return filter_list( + result = filter_list( wildcards, - cast("Mapping[str, str | Sequence[str]]", filters), - regex_search=regex_search, + filters.without_bools, + regex_search=False, ) + if not filters.post_exclusions: + return result + + return filter_list(result, filters.post_exclusions, regex_search=True) def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, str]]: @@ -568,45 +572,142 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, return "".join(new_path), wildcard_values -@attrs.define -class _GetListsFromBidsSteps: - input_name: str +def _get_matching_files( + bids_layout: BIDSLayout, + 
filters: _UnifiedFilter, +) -> Iterable[BIDSFile]: + if filters.has_empty_prefilter: + return [] + try: + return bids_layout.get( + regex_search=False, + **filters.filters, + ) + except AttributeError as err: + raise PybidsError( + "Pybids has encountered a problem that Snakebids cannot handle. This " + "may indicate a missing or invalid dataset_description.json for this " + "dataset." + ) from err - FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" - def _get_invalid_filters(self, filters: FilterType): - for key, filts in filters.items(): - try: - if True in filts or False in filts: # type: ignore - yield key, filts - except TypeError: - pass - - def prepare_bids_filters( - self, filters: Mapping[str, str | bool | Sequence[str | bool]] - ) -> dict[str, str | Query | list[str | Query]]: +def _compile_filters(filters: FilterType) -> CompiledFilter: + return { + key: [ + Query.ANY if f is True else Query.NONE if f is False else f + for f in itx.always_iterable(filts) + ] + for key, filts in filters.items() + } + + +@attrs.define(slots=False) +class _UnifiedFilter: + """Manages component level and post filters""" + + component: InputConfig + """The Component configuration defining the filters""" + + postfilters: _Postfilter + """Filters to be applied after collecting and parsing the data + + Currently only used to implement --[exclude-]participant-label, but in the future, + may implement other such CLI args. Unlike configuration-defined filters, these + filters apply after the dataset is indexed and queried. Thus, if a filter is set + to an empty list, a complete, albeit empty, component may be found. This is akin to + calling ``BidsComponent.filter`` after running ``generate_inputs``. 
+ + For performance purposes, non-empty post-filters are applied via ``pybids.get()`` + """ + + @classmethod + def from_filter_dict( + cls, + filters: Mapping[str, str | bool | Sequence[str | bool]], + postfilter: _Postfilter | None = None, + ): + """Patch together a UnifiedFilter based on a basic filter dict + + Intended primarily for use in testing + """ + return cls({"filters": filters}, postfilter or _Postfilter()) + + def _has_empty_list(self, items: Iterable[Any]): + """Check if any of the lists within iterable are empty""" + return any(itx.ilen(itx.always_iterable(item)) == 0 for item in items) + + def _has_overlap(self, key: str): + """Check if filter key is a wildcard and not already a prefilter""" + return key not in self.prefilters and key in self.component.get("wildcards", []) + + @ft.cached_property + def prefilters(self): + """Filters defined in the component configuration and applied via pybids + + Unlike postfilters, a prefilter set to an empty list will result in no valid + paths found, resulting in a blank (missing) component. + """ + filters = dict(self.component.get("filters", {})) + # Silently remove "regex_search". This value has been blocked by a bug for the + # since version 0.6, and even before, never fully worked properly (e.g. would + # break if combined with --exclude-participant-label) + if "regex_search" in filters: + del filters["regex_search"] + return filters + + @ft.cached_property + def filters(self): + """The combination pre- and post- filters to be applied to pybids indexing + + Includes all pre-filters, and all inclusion post-filters. Empty post-filters + are replaced with Query.ANY. This allows valid paths to be found and processed + later. 
Post-filters are not applied when an equivalent prefilter is present + """ + result = dict(_compile_filters(self.prefilters)) + postfilters = self.postfilters.inclusions + for key in self.postfilters.inclusions: + if self._has_overlap(key): + # if empty list filter, ensure the entity filtered is present + result[key] = ( + postfilters[key] + if itx.ilen(itx.always_iterable(postfilters[key])) + else Query.ANY + ) + return result + + @property + def post_exclusions(self): + """Dictionary of all post-exclusion filters""" return { - key: [ - Query.ANY if f is True else Query.NONE if f is False else f - for f in itx.always_iterable(filts) - ] - for key, filts in filters.items() + key: val + for key, val in self.postfilters.exclusions.items() + if self._has_overlap(key) } - def get_matching_files( - self, - bids_layout: BIDSLayout, - regex_search: bool, - filters: Mapping[str, str | Query | Sequence[str | Query]], - ) -> Iterable[BIDSFile]: - try: - return bids_layout.get(regex_search=regex_search, **filters) - except AttributeError as err: - raise PybidsError( - "Pybids has encountered a problem that Snakebids cannot handle. This " - "may indicate a missing or invalid dataset_description.json for this " - "dataset." 
- ) from err + @property + def without_bools(self): + """Check and typeguard to ensure filters do not contain booleans""" + for key, val in self.filters.items(): + if any(isinstance(v, Query) for v in itx.always_iterable(val)): + raise ValueError( + "Boolean filters in items with custom paths are not supported; in " + f"component='{key}'" + ) + return cast("Mapping[str, str | Sequence[str]]", self.filters) + + @property + def has_empty_prefilter(self): + """Returns True if even one prefilter is empty""" + return self._has_empty_list(self.prefilters.values()) + + @property + def has_empty_postfilter(self): + """Returns True if even one postfilter is empty""" + return self._has_empty_list( + filt + for name, filt in self.postfilters.inclusions.items() + if self._has_overlap(name) + ) def _get_lists_from_bids( @@ -614,8 +715,7 @@ def _get_lists_from_bids( pybids_inputs: InputsConfig, *, limit_to: Iterable[str] | None = None, - regex_search: bool = False, - **filters: str | Sequence[str], + postfilters: _Postfilter, ) -> Generator[BidsComponent, None, None]: """Grabs files using pybids and creates snakemake-friendly lists @@ -641,10 +741,11 @@ def _get_lists_from_bids( One BidsComponent is yielded for each modality described by ``pybids_inputs``. 
""" for input_name in limit_to or list(pybids_inputs): - steps = _GetListsFromBidsSteps(input_name) _logger.debug("Grabbing inputs for %s...", input_name) component = pybids_inputs[input_name] + filters = _UnifiedFilter(component, postfilters or {}) + if "custom_path" in component: # a custom path was specified for this input, skip pybids: # get input_wildcards by parsing path for {} entries (using a set @@ -653,12 +754,7 @@ def _get_lists_from_bids( # to deal with multiple wildcards path = component["custom_path"] - zip_lists = _parse_custom_path( - path, - regex_search=regex_search, - **pybids_inputs[input_name].get("filters", {}), - **filters, - ) + zip_lists = _parse_custom_path(path, filters=filters) yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) continue @@ -670,10 +766,7 @@ def _get_lists_from_bids( zip_lists: dict[str, list[str]] = defaultdict(list) paths: set[str] = set() - pybids_filters = steps.prepare_bids_filters(component.get("filters", {})) - matching_files = steps.get_matching_files( - bids_layout, regex_search, {**pybids_filters, **filters} - ) + matching_files = _get_matching_files(bids_layout, filters) for img in matching_files: wildcards: list[str] = [ @@ -739,7 +832,15 @@ def _get_lists_from_bids( f"narrow the search. 
Found filenames: {paths}" ) - yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) + if filters.has_empty_postfilter: + yield BidsComponent( + name=input_name, path=path, zip_lists={key: [] for key in zip_lists} + ) + continue + + yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists).filter( + regex_search=True, **filters.post_exclusions + ) def get_wildcard_constraints(image_types: InputsConfig) -> dict[str, str]: diff --git a/snakebids/tests/helpers.py b/snakebids/tests/helpers.py index c9c848e7..b279300b 100644 --- a/snakebids/tests/helpers.py +++ b/snakebids/tests/helpers.py @@ -206,7 +206,11 @@ def create_snakebids_config(dataset: BidsDataset) -> InputsConfig: def reindex_dataset( - root: str, dataset: BidsDataset, use_custom_paths: bool = False + root: str, + dataset: BidsDataset, + use_custom_paths: bool = False, + participant_label: str | Sequence[str] | None = None, + exclude_participant_label: str | Sequence[str] | None = None, ) -> BidsDataset: """Create BidsDataset on the filesystem and reindex @@ -214,10 +218,20 @@ def reindex_dataset( """ create_dataset(Path("/"), dataset) config = create_snakebids_config(dataset) + if participant_label is not None or exclude_participant_label is not None: + for comp in config.values(): + if "subject" in comp.get("filters", {}): + del comp["filters"]["subject"] # type: ignore + if use_custom_paths: for comp in config: config[comp]["custom_path"] = dataset[comp].path - return generate_inputs(root, config) + return generate_inputs( + root, + config, + participant_label=participant_label, + exclude_participant_label=exclude_participant_label, + ) def allow_function_scoped(callable: _T, /) -> _T: diff --git a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index 4f01ad5a..e1721e6d 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -12,7 +12,16 @@ import tempfile from collections import defaultdict from 
pathlib import Path -from typing import Any, Callable, Iterable, NamedTuple, TypeVar, cast +from typing import ( + Any, + Callable, + Iterable, + Literal, + NamedTuple, + TypedDict, + TypeVar, + cast, +) import attrs import more_itertools as itx @@ -22,24 +31,28 @@ from hypothesis import strategies as st from pyfakefs.fake_filesystem import FakeFilesystem from pytest_mock import MockerFixture +from snakemake.io import expand as sb_expand from snakebids.core.datasets import BidsComponent, BidsDataset from snakebids.core.input_generation import ( _all_custom_paths, _gen_bids_layout, - _generate_filters, _get_lists_from_bids, _parse_bids_path, _parse_custom_path, + _Postfilter, + _UnifiedFilter, generate_inputs, ) -from snakebids.exceptions import ConfigError, PybidsError, RunError +from snakebids.exceptions import ConfigError, RunError from snakebids.paths.presets import bids from snakebids.tests import strategies as sb_st from snakebids.tests.helpers import ( BidsListCompare, allow_function_scoped, create_dataset, + create_snakebids_config, + debug, example_if, get_bids_path, get_zip_list, @@ -427,7 +440,7 @@ def test_missing_wildcards(self, tmpdir: Path): assert config.subj_wildcards == {"subject": "{subject}"} -class TestGenerateFilter: +class TestPostfilter: valid_chars = st.characters(blacklist_characters=["\n"]) st_lists_or_text = st.lists(st.text(valid_chars)) | st.text(valid_chars) @@ -435,48 +448,53 @@ class TestGenerateFilter: def test_throws_error_if_labels_and_excludes_are_given( self, args: tuple[list[str] | str, list[str] | str] ): + filters = _Postfilter() with pytest.raises(ValueError): - _generate_filters(*args) + filters.add_filter("foo", *args) - @given(st_lists_or_text) - def test_returns_participant_label_as_list(self, label: list[str] | str): - result = _generate_filters(label)[0] + @given(st.text(), st_lists_or_text) + def test_returns_participant_label_as_dict(self, key: str, label: list[str] | str): + filters = _Postfilter() + 
filters.add_filter(key, label, None) if isinstance(label, str): - assert result == [label] + assert filters.inclusions == {key: [label]} else: - assert result == label + assert filters.inclusions == {key: label} + assert filters.exclusions == {} @given( + st.text(), st_lists_or_text, st.lists(st.text(valid_chars, min_size=1), min_size=1), st.text(valid_chars, min_size=1, max_size=3), ) def test_exclude_gives_regex_that_matches_anything_except_exclude( - self, excluded: list[str] | str, dummy_values: list[str], padding: str + self, key: str, excluded: list[str] | str, dummy_values: list[str], padding: str ): + filters = _Postfilter() # Make sure the dummy_values and padding we'll be testing against are different # from our test values for value in dummy_values: assume(value not in itx.always_iterable(excluded)) assume(padding not in itx.always_iterable(excluded)) - result = _generate_filters(exclude=excluded) - assert result[1] is True - assert isinstance(result[0], list) - assert len(result[0]) == 1 + filters.add_filter(key, None, excluded) + assert isinstance(filters.exclusions[key], list) + assert len(filters.exclusions[key]) == 1 # We match any value that isn't the exclude string for value in dummy_values: - assert re.match(result[0][0], value) + assert re.match(filters.exclusions[key][0], value) for exclude in itx.always_iterable(excluded): # We don't match the exclude string - assert re.match(result[0][0], exclude) is None + assert re.match(filters.exclusions[key][0], exclude) is None # Addition of random strings before and/or after lets the match occur again - assert re.match(result[0][0], padding + exclude) - assert re.match(result[0][0], exclude + padding) - assert re.match(result[0][0], padding + exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude) + assert re.match(filters.exclusions[key][0], exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude + padding) + assert filters.inclusions == {} class 
PathEntities(NamedTuple): @@ -592,7 +610,7 @@ def test_collects_all_paths_when_no_filters( test_path = self.generate_test_directory(entities, template, temp_dir) # Test without any filters - result = _parse_custom_path(test_path) + result = _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict({})) zip_lists = get_zip_list(entities, it.product(*entities.values())) assert BidsComponent( name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists @@ -612,7 +630,7 @@ def test_collects_only_filtered_entities( # Test with filters result_filtered = MultiSelectDict( - _parse_custom_path(test_path, regex_search=False, **filters) + _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict(filters)) ) zip_lists = MultiSelectDict( { @@ -628,6 +646,11 @@ def test_collects_only_filtered_entities( name="foo", path=get_bids_path(result_filtered), zip_lists=result_filtered ) + @debug( + path_entities=PathEntities( + entities={"A": ["A"]}, template=Path("A-{A}"), filters={"A": ["A"]} + ), + ) @settings(deadline=400, suppress_health_check=[HealthCheck.function_scoped_fixture]) @given(path_entities=path_entities()) def test_collect_all_but_filters_when_exclusion_filters_used( @@ -638,14 +661,13 @@ def test_collect_all_but_filters_when_exclusion_filters_used( entities, template, filters = path_entities test_path = self.generate_test_directory(entities, template, temp_dir) # Test with exclusion filters - exclude_filters = { - # We use _generate_filter to get our exclusion regex. 
This function was - # tested previously - key: _generate_filters(exclude=values)[0] - for key, values in filters.items() - } + exclude_filters = _Postfilter() + for key, values in filters.items(): + exclude_filters.add_filter(key, None, values) result_excluded = MultiSelectDict( - _parse_custom_path(test_path, regex_search=True, **exclude_filters) + _parse_custom_path( + test_path, _UnifiedFilter.from_filter_dict({}, exclude_filters) + ) ) entities_excluded = { @@ -971,26 +993,6 @@ def test_t1w_with_dict(): assert config["subj_wildcards"] == {"subject": "{subject}"} -@pytest.mark.skipif( - sys.version_info >= (3, 8), - reason=""" - Bug only surfaces on python 3.7 because higher python versions have access to the - latest pybids version - """, -) -def test_get_lists_from_bids_raises_pybids_error(): - """Test that we wrap a cryptic AttributeError from pybids with PybidsError. - - Pybids raises an AttributeError when a BIDSLayout.get is called with scope not - equal to 'all' on a layout that indexes a dataset without a - dataset_description.json. We wrap this error with something a bit less cryptic, so - this test ensures that that behaviour is still present. 
- """ - layout = BIDSLayout("snakebids/tests/data/bids_t1w", validate=False) - with pytest.raises(PybidsError): - next(_get_lists_from_bids(layout, {"t1": {"filters": {"scope": "raw"}}})) - - def test_get_lists_from_bids_raises_run_error(): bids_layout = None pybids_inputs: InputsConfig = { @@ -1000,7 +1002,9 @@ def test_get_lists_from_bids_raises_run_error(): } } with pytest.raises(RunError): - next(_get_lists_from_bids(bids_layout, pybids_inputs)) + next( + _get_lists_from_bids(bids_layout, pybids_inputs, postfilters=_Postfilter()) + ) def test_get_lists_from_bids(): @@ -1039,7 +1043,7 @@ def test_get_lists_from_bids(): pybids_inputs["t1"]["custom_path"] = wildcard_path_t1 pybids_inputs["t2"]["custom_path"] = wildcard_path_t2 - result = _get_lists_from_bids(layout, pybids_inputs) + result = _get_lists_from_bids(layout, pybids_inputs, postfilters=_Postfilter()) for bids_lists in result: if bids_lists.input_name == "t1": template = BidsComponent( @@ -1155,6 +1159,216 @@ def test_generate_inputs(dataset: BidsDataset, bids_fs: Path, fakefs_tmpdir: Pat assert reindexed.layout is not None +@st.composite +def dataset_with_subject(draw: st.DrawFn): + entities = draw(sb_st.bids_entity_lists(blacklist_entities=["subject"])) + entities += ["subject"] + return BidsDataset.from_iterable( + [ + draw( + sb_st.bids_components( + whitelist_entities=entities, + min_entities=len(entities), + max_entities=len(entities), + restrict_patterns=True, + unique=True, + ) + ) + ] + ) + + +class TestParticipantFiltering: + MODE = Literal["include", "exclude"] + + @pytest.fixture + def tmpdir(self, bids_fs: Path, fakefs_tmpdir: Path): + return fakefs_tmpdir + + def get_filter_params(self, mode: MODE, filters: list[str] | str): + class FiltParams(TypedDict, total=False): + participant_label: list[str] | str + exclude_participant_label: list[str] | str + + if mode == "include": + return FiltParams({"participant_label": filters}) + elif mode == "exclude": + return 
FiltParams({"exclude_participant_label": filters}) + raise ValueError(f"Invalid mode specification: {mode}") + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_filters_comps_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, participant_label=label) + assert set(itx.first(reindexed.values()).entities["subject"]) == set( + itx.always_iterable(label) + ) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_label_filters_comp_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=sb_st.datasets_one_comp(blacklist_entities=["subject"], unique=True), + participant_filter=st.lists(st.text(min_size=1)) | 
st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_without_subject( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + reindexed = reindex_dataset( + root, rooted, **self.get_filter_params(mode, participant_filter) + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + create_dataset(Path("/"), rooted) + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + data=st.data(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter_no_wcard( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + data: st.DataObject, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + 
attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + subject = data.draw( + st.sampled_from(itx.first(rooted.values()).entities["subject"]) + ) + create_dataset(Path("/"), rooted) + config = create_snakebids_config(rooted) + for comp in config.values(): + comp["filters"] = dict(comp.get("filters", {})) + comp["filters"]["subject"] = subject + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @given( + data=st.data(), + dataset=dataset_with_subject().filter( + lambda ds: set(itx.first(ds.values()).wildcards) != {"subject", "extension"} + ), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_does_not_make_all_other_filters_regex( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + + # Create an extra set of paths by modifying one of the existing components to put
If that filter gets changed to a regex, all + # of the suffixed decoys will get picked up by pybids + ziplist = dict(itx.first(rooted.values()).zip_lists) + mut_entity = itx.first( + filter(lambda e: e not in {"subject", "extension"}, ziplist) + ) + ziplist[mut_entity] = ["foo" + v for v in ziplist[mut_entity]] + for path in sb_expand(itx.first(rooted.values()).path, zip, **ziplist): + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + p.touch() + + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + # The content of the dataset is irrelevant to this test, so one example suffices # but can't use extension because the custom path won't glob properly @settings(max_examples=1, suppress_health_check=[HealthCheck.function_scoped_fixture]) diff --git a/snakebids/types.py b/snakebids/types.py index 6726da63..ac5c5f97 100644 --- a/snakebids/types.py +++ b/snakebids/types.py @@ -23,7 +23,7 @@ class InputConfig(TypedDict, total=False): """Configuration for a single bids component""" - filters: dict[str, str | bool | list[str | bool]] + filters: Mapping[str, str | bool | Sequence[str | bool]] """Filters to pass on to :class:`BIDSLayout.get() ` Each key refers to the name of an entity. Values may take the following forms: