diff --git a/snakebids/core/datasets.py b/snakebids/core/datasets.py index 1c910ed2..8b3eba7c 100644 --- a/snakebids/core/datasets.py +++ b/snakebids/core/datasets.py @@ -446,6 +446,8 @@ def filter( if not isinstance(regex_search, bool): msg = "regex_search must be a boolean" raise TypeError(msg) + if not filters: + return self return attr.evolve( self, zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search), diff --git a/snakebids/core/filtering.py b/snakebids/core/filtering.py index 87a6b570..b67be33d 100644 --- a/snakebids/core/filtering.py +++ b/snakebids/core/filtering.py @@ -25,7 +25,7 @@ def filter_list( def filter_list( zip_list: ZipListLike, filters: Mapping[str, Iterable[str] | str], - return_indices_only: Literal[True] = ..., + return_indices_only: Literal[True], regex_search: bool = ..., ) -> list[int]: ... diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index fef7512b..0710a9ec 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -1,6 +1,7 @@ """Utilities for converting Snakemake apps to BIDS apps.""" from __future__ import annotations +import functools as ft import json import logging import os @@ -24,12 +25,15 @@ PybidsError, RunError, ) -from snakebids.types import InputsConfig, ZipList +from snakebids.types import InputConfig, InputsConfig, ZipList from snakebids.utils.snakemake_io import glob_wildcards from snakebids.utils.utils import BidsEntity, BidsParseError, get_first_dir _logger = logging.getLogger(__name__) +FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" +CompiledFilter: TypeAlias = "Mapping[str, str | Query | Sequence[str | Query]]" + @overload def generate_inputs( @@ -250,9 +254,8 @@ def generate_inputs( }) """ - subject_filter, regex_search = _generate_filters( - participant_label, exclude_participant_label - ) + postfilters = _Postfilter() + postfilters.add_filter("subject", participant_label, exclude_participant_label) if pybids_database_dir: _logger.warning( @@ -281,13 +284,11 @@ def generate_inputs( else None ) - filters = {"subject": subject_filter} if subject_filter else {} bids_inputs = _get_lists_from_bids( bids_layout=layout, pybids_inputs=pybids_inputs, limit_to=limit_to, - regex_search=regex_search, - **(filters), + postfilters=postfilters, ) if use_bids_inputs is True: @@ -403,64 +404,69 @@ def write_derivative_json(snakemake: Snakemake, **kwargs: dict[str, Any]) -> Non json.dump(sidecar, outfile, indent=4) -def _generate_filters( - include: Iterable[str] | str | None = None, - exclude: Iterable[str] | str | None = None, -) -> tuple[list[str], bool]: - """Generate Pybids filter based on inclusion or exclusion criteria - - Converts either a list of values to include or exclude in a list of Pybids - compatible filters. Unlike inclusion values, exclusion requires regex filtering. The - necessity for regex will be indicated by the boolean value of the second returned - item: True if regex is needed, False otherwise. 
Raises an exception if both include
-    and exclude are stipulated
-
-    Parameters
-    ----------
-    include : list of str or str, optional
-        Values to include, values not found in this list will be excluded, by default
-        None
-    exclude : list of str or str, optional
-        Values to exclude, only values not found in this list will be included, by
-        default None
+class _Postfilter:
+    """Filters to apply after indexing, typically derived from the CLI
 
-    Returns
-    -------
-    list of str, bool
-        Two values: the first, a list of pybids compatible filters; the second, a
-        boolean indicating whether regex_search must be enabled in pybids
-
-    Raises
-    ------
-    ValueError Raised of both include and exclude values are stipulated.
+    Currently used to support ``--[exclude-]participant-label``
     """
-    if include is not None and exclude is not None:
-        raise ValueError(
-            "Cannot define both participant_label and "
-            "exclude_participant_label at the same time"
-        )
 
-    # add participant_label or exclude_participant_label to search terms (if
-    # defined)
-    # we make the item key in search_terms a list so we can have both
-    # include and exclude defined
-    if include is not None:
-        return [*itx.always_iterable(include)], False
+    def __init__(self):
+        self.inclusions: dict[str, Sequence[str] | str] = {}
+        self.exclusions: dict[str, Sequence[str] | str] = {}
 
-    if exclude is not None:
+    def add_filter(
+        self,
+        key: str,
+        inclusions: Iterable[str] | str | None,
+        exclusions: Iterable[str] | str | None,
+    ):
+        """Add an entity filter based on inclusion or exclusion criteria
+
+        Converts a list of values to include or exclude into pybids-compatible
+        filters. Inclusions are stored as-is, while exclusions are converted into a
+        regex that matches anything except the excluded values. The ``_Postfilter``
+        is modified in place.
+
+        Parameters
+        ----------
+        key
+            Name of the entity to be filtered
+        inclusions
+            Values to include; values not found in this list will be excluded, by
+            default None
+        exclusions
+            Values to exclude; only values not found in this list will be included,
+            by default None
+
+        Raises
+        ------
+        ValueError
+            Raised if both inclusion and exclusion values are stipulated
+        """
+        if inclusions is not None and exclusions is not None:
+            raise ValueError(
+                "Cannot define both participant_label and "
+                "exclude_participant_label at the same time"
+            )
+        if inclusions is not None:
+            self.inclusions[key] = list(itx.always_iterable(inclusions))
+        if exclusions is not None:
+            self.exclusions[key] = self._format_exclusions(exclusions)
+
+    def _format_exclusions(self, exclusions: Iterable[str] | str):
+        # if multiple items to exclude, combine with item1|item2|...
         exclude_string = "|".join(
-            re.escape(label) for label in itx.always_iterable(exclude)
+            re.escape(label) for label in itx.always_iterable(exclusions)
         )
         # regex to exclude subjects
-        return [f"^((?!({exclude_string})$).*)$"], True
-    return [], False
+        return [f"^((?!({exclude_string})$).*)$"]
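
For reference, the regex built by `_format_exclusions` wraps the escaped labels in a negative lookahead, so it rejects exactly the excluded labels while matching everything else. A minimal sketch of the intended behaviour once this diff is applied (the entity name and labels below are hypothetical, not taken from the test suite):

```python
import re

from snakebids.core.input_generation import _Postfilter

post = _Postfilter()
post.add_filter("subject", None, ["01", "02"])

# Exclusions compile to a single negative-lookahead regex: ^((?!(01|02)$).*)$
pattern = post.exclusions["subject"][0]

assert re.match(pattern, "01") is None  # exact excluded labels are rejected
assert re.match(pattern, "03")          # other labels still match
assert re.match(pattern, "012")         # labels merely containing "01" match too
assert post.inclusions == {}            # inclusions are untouched
```
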
 
 
 def _parse_custom_path(
     input_path: Path | str,
-    regex_search: bool = False,
-    **filters: Sequence[str | bool] | str | bool,
+    filters: _UnifiedFilter,
 ) -> ZipList:
     """Glob wildcards from a custom path and apply filters
 
@@ -495,17 +501,15 @@
         return wildcards
 
     # Return the output values, running filtering on the zip_lists
-    if any(
-        isinstance(v, bool) for f in filters.values() for v in itx.always_iterable(f)
-    ):
-        raise TypeError(
-            "boolean filters are not currently supported in custom path filtering"
-        )
-    return filter_list(
+    result = filter_list(
         wildcards,
-        cast("Mapping[str, str | Sequence[str]]", filters),
-        regex_search=regex_search,
+        filters.without_bools,
+        regex_search=False,
     )
+    if not filters.post_exclusions:
+        return result
+
+    return filter_list(result, filters.post_exclusions, regex_search=True)
 
 
 def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, str]]:
@@ -568,45 +572,142 @@
     return "".join(new_path), wildcard_values
 
 
-@attrs.define
-class _GetListsFromBidsSteps:
-    input_name: str
+def _get_matching_files(
+    bids_layout: BIDSLayout,
+    filters: _UnifiedFilter,
+) -> Iterable[BIDSFile]:
+    if filters.has_empty_prefilter:
+        return []
+    try:
+        return bids_layout.get(
+            regex_search=False,
+            **filters.filters,
+        )
+    except AttributeError as err:
+        raise PybidsError(
+            "Pybids has encountered a problem that Snakebids cannot handle. This "
+            "may indicate a missing or invalid dataset_description.json for this "
+            "dataset."
+        ) from err
 
-    FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]"
 
-    def _get_invalid_filters(self, filters: FilterType):
-        for key, filts in filters.items():
-            try:
-                if True in filts or False in filts:  # type: ignore
-                    yield key, filts
-            except TypeError:
-                pass
-
-    def prepare_bids_filters(
-        self, filters: Mapping[str, str | bool | Sequence[str | bool]]
-    ) -> dict[str, str | Query | list[str | Query]]:
+def _compile_filters(filters: FilterType) -> CompiledFilter:
+    return {
+        key: [
+            Query.ANY if f is True else Query.NONE if f is False else f
+            for f in itx.always_iterable(filts)
+        ]
+        for key, filts in filters.items()
+    }
+
+
+@attrs.define(slots=False)
+class _UnifiedFilter:
+    """Manages component-level filters and post-filters"""
+
+    component: InputConfig
+    """The component configuration defining the filters"""
+
+    postfilters: _Postfilter
+    """Filters to be applied after collecting and parsing the data
+
+    Currently used only to implement ``--[exclude-]participant-label``, but may in
+    the future support other such CLI args. Unlike configuration-defined filters,
+    these filters apply after the dataset is indexed and queried. Thus, if a filter
+    is set to an empty list, a complete, albeit empty, component may be found. This
+    is akin to calling ``BidsComponent.filter`` after running ``generate_inputs``.
+
+    For performance purposes, non-empty post-filters are applied via ``pybids.get()``
+    """
+
+    @classmethod
+    def from_filter_dict(
+        cls,
+        filters: Mapping[str, str | bool | Sequence[str | bool]],
+        postfilter: _Postfilter | None = None,
+    ):
+        """Patch together a ``_UnifiedFilter`` based on a basic filter dict
+
+        Intended primarily for use in testing
+        """
+        return cls({"filters": filters}, postfilter or _Postfilter())
+
+    def _has_empty_list(self, items: Iterable[Any]):
+        """Check if any of the lists within iterable are empty"""
+        return any(itx.ilen(itx.always_iterable(item)) == 0 for item in items)
+
+    def _has_overlap(self, key: str):
+        """Check if filter key is a wildcard and not already a prefilter"""
+        return key not in self.prefilters and key in self.component.get("wildcards", [])
+
+    @ft.cached_property
+    def prefilters(self):
+        """Filters defined in the component configuration and applied via pybids
+
+        Unlike postfilters, a prefilter set to an empty list will result in no valid
+        paths found, resulting in a blank (missing) component.
+        """
+        filters = dict(self.component.get("filters", {}))
+        # Silently remove "regex_search". This value has been blocked by a bug since
+        # version 0.6 and, even before that, never fully worked properly (e.g. it
+        # would break if combined with --exclude-participant-label)
+        if "regex_search" in filters:
+            del filters["regex_search"]
+        return filters
+
+    @ft.cached_property
+    def filters(self):
+        """The combined pre- and post-filters to be applied to pybids indexing
+
+        Includes all pre-filters and all inclusion post-filters. Empty post-filters
+        are replaced with Query.ANY: this allows valid paths to be found and processed
+        later. Post-filters are not applied when an equivalent prefilter is present.
+        """
+        result = dict(_compile_filters(self.prefilters))
+        postfilters = self.postfilters.inclusions
+        for key in postfilters:
+            if self._has_overlap(key):
+                # if the filter is an empty list, ensure the filtered entity is present
+                result[key] = (
+                    postfilters[key]
+                    if itx.ilen(itx.always_iterable(postfilters[key]))
+                    else Query.ANY
+                )
+        return result
+
+    @property
+    def post_exclusions(self):
+        """Dictionary of all post-exclusion filters"""
        return {
-            key: [
-                Query.ANY if f is True else Query.NONE if f is False else f
-                for f in itx.always_iterable(filts)
-            ]
-            for key, filts in filters.items()
+            key: val
+            for key, val in self.postfilters.exclusions.items()
+            if self._has_overlap(key)
        }
 
-    def get_matching_files(
-        self,
-        bids_layout: BIDSLayout,
-        regex_search: bool,
-        filters: Mapping[str, str | Query | Sequence[str | Query]],
-    ) -> Iterable[BIDSFile]:
-        try:
-            return bids_layout.get(regex_search=regex_search, **filters)
-        except AttributeError as err:
-            raise PybidsError(
-                "Pybids has encountered a problem that Snakebids cannot handle. This "
-                "may indicate a missing or invalid dataset_description.json for this "
-                "dataset."
-            ) from err
+    @property
+    def without_bools(self):
+        """Check and typeguard to ensure filters do not contain booleans"""
+        for key, val in self.filters.items():
+            if any(isinstance(v, Query) for v in itx.always_iterable(val)):
+                raise ValueError(
+                    "Boolean filters are not supported in components with custom "
+                    f"paths; found one under filter '{key}'"
+                )
+        return cast("Mapping[str, str | Sequence[str]]", self.filters)
+
+    @property
+    def has_empty_prefilter(self):
+        """Returns True if even one prefilter is empty"""
+        return self._has_empty_list(self.prefilters.values())
+
+    @property
+    def has_empty_postfilter(self):
+        """Returns True if even one postfilter is empty"""
+        return self._has_empty_list(
+            filt
+            for name, filt in self.postfilters.inclusions.items()
+            if self._has_overlap(name)
+        )
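
To summarize the division of labour above: prefilters come from the component config and run inside the pybids query; inclusion post-filters that target a component wildcard are folded into that same query (an empty list is swapped for `Query.ANY` so an empty-but-complete component can still be built); exclusion post-filters run afterwards as regexes. A rough walkthrough under those semantics — the component dicts and labels here are invented for illustration, and assume this diff is applied:

```python
from bids.layout import Query

from snakebids.core.input_generation import _Postfilter, _UnifiedFilter

post = _Postfilter()
post.add_filter("subject", ["01", "02"], None)

unified = _UnifiedFilter(
    {"filters": {"suffix": "T1w"}, "wildcards": ["subject", "session"]},
    post,
)

# Component filters become prefilters, compiled for the pybids query
assert unified.prefilters == {"suffix": "T1w"}
# Inclusion post-filters on a wildcard entity join the pybids query
assert unified.filters == {"suffix": ["T1w"], "subject": ["01", "02"]}

# An empty inclusion list becomes Query.ANY, so paths are still indexed and
# an empty-but-complete component can be produced afterwards
empty = _Postfilter()
empty.add_filter("subject", [], None)
unified2 = _UnifiedFilter({"wildcards": ["subject"]}, empty)
assert unified2.filters == {"subject": Query.ANY}
assert unified2.has_empty_postfilter
```
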
 
 
 def _get_lists_from_bids(
@@ -614,8 +715,7 @@
     pybids_inputs: InputsConfig,
     *,
     limit_to: Iterable[str] | None = None,
-    regex_search: bool = False,
-    **filters: str | Sequence[str],
+    postfilters: _Postfilter,
 ) -> Generator[BidsComponent, None, None]:
     """Grabs files using pybids and creates snakemake-friendly lists
 
@@ -641,10 +741,11 @@
     One BidsComponent is yielded for each modality described by ``pybids_inputs``.
     """
     for input_name in limit_to or list(pybids_inputs):
-        steps = _GetListsFromBidsSteps(input_name)
         _logger.debug("Grabbing inputs for %s...", input_name)
         component = pybids_inputs[input_name]
 
+        filters = _UnifiedFilter(component, postfilters)
+
         if "custom_path" in component:
             # a custom path was specified for this input, skip pybids:
             # get input_wildcards by parsing path for {} entries (using a set
@@ -653,12 +754,7 @@
             # to deal with multiple wildcards
             path = component["custom_path"]
 
-            zip_lists = _parse_custom_path(
-                path,
-                regex_search=regex_search,
-                **pybids_inputs[input_name].get("filters", {}),
-                **filters,
-            )
+            zip_lists = _parse_custom_path(path, filters=filters)
 
             yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists)
             continue
@@ -670,10 +766,7 @@
         zip_lists: dict[str, list[str]] = defaultdict(list)
         paths: set[str] = set()
 
-        pybids_filters = steps.prepare_bids_filters(component.get("filters", {}))
-        matching_files = steps.get_matching_files(
-            bids_layout, regex_search, {**pybids_filters, **filters}
-        )
+        matching_files = _get_matching_files(bids_layout, filters)
 
         for img in matching_files:
             wildcards: list[str] = [
@@ -739,7 +832,15 @@
                 f"narrow the search. 
Found filenames: {paths}" ) - yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) + if filters.has_empty_postfilter: + yield BidsComponent( + name=input_name, path=path, zip_lists={key: [] for key in zip_lists} + ) + continue + + yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists).filter( + regex_search=True, **filters.post_exclusions + ) def get_wildcard_constraints(image_types: InputsConfig) -> dict[str, str]: diff --git a/snakebids/tests/helpers.py b/snakebids/tests/helpers.py index c9c848e7..b279300b 100644 --- a/snakebids/tests/helpers.py +++ b/snakebids/tests/helpers.py @@ -206,7 +206,11 @@ def create_snakebids_config(dataset: BidsDataset) -> InputsConfig: def reindex_dataset( - root: str, dataset: BidsDataset, use_custom_paths: bool = False + root: str, + dataset: BidsDataset, + use_custom_paths: bool = False, + participant_label: str | Sequence[str] | None = None, + exclude_participant_label: str | Sequence[str] | None = None, ) -> BidsDataset: """Create BidsDataset on the filesystem and reindex @@ -214,10 +218,20 @@ def reindex_dataset( """ create_dataset(Path("/"), dataset) config = create_snakebids_config(dataset) + if participant_label is not None or exclude_participant_label is not None: + for comp in config.values(): + if "subject" in comp.get("filters", {}): + del comp["filters"]["subject"] # type: ignore + if use_custom_paths: for comp in config: config[comp]["custom_path"] = dataset[comp].path - return generate_inputs(root, config) + return generate_inputs( + root, + config, + participant_label=participant_label, + exclude_participant_label=exclude_participant_label, + ) def allow_function_scoped(callable: _T, /) -> _T: diff --git a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index 4f01ad5a..e1721e6d 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -12,7 +12,16 @@ import tempfile from collections import defaultdict from pathlib import Path -from typing import Any, Callable, Iterable, NamedTuple, TypeVar, cast +from typing import ( + Any, + Callable, + Iterable, + Literal, + NamedTuple, + TypedDict, + TypeVar, + cast, +) import attrs import more_itertools as itx @@ -22,24 +31,28 @@ from hypothesis import strategies as st from pyfakefs.fake_filesystem import FakeFilesystem from pytest_mock import MockerFixture +from snakemake.io import expand as sb_expand from snakebids.core.datasets import BidsComponent, BidsDataset from snakebids.core.input_generation import ( _all_custom_paths, _gen_bids_layout, - _generate_filters, _get_lists_from_bids, _parse_bids_path, _parse_custom_path, + _Postfilter, + _UnifiedFilter, generate_inputs, ) -from snakebids.exceptions import ConfigError, PybidsError, RunError +from snakebids.exceptions import ConfigError, RunError from snakebids.paths.presets import bids from snakebids.tests import strategies as sb_st from snakebids.tests.helpers import ( BidsListCompare, allow_function_scoped, create_dataset, + create_snakebids_config, + debug, example_if, get_bids_path, get_zip_list, @@ -427,7 +440,7 @@ def test_missing_wildcards(self, tmpdir: Path): assert config.subj_wildcards == {"subject": "{subject}"} -class TestGenerateFilter: +class TestPostfilter: valid_chars = st.characters(blacklist_characters=["\n"]) st_lists_or_text = st.lists(st.text(valid_chars)) | st.text(valid_chars) @@ -435,48 +448,53 @@ class TestGenerateFilter: def test_throws_error_if_labels_and_excludes_are_given( self, args: tuple[list[str] | str, list[str] 
| str] ): + filters = _Postfilter() with pytest.raises(ValueError): - _generate_filters(*args) + filters.add_filter("foo", *args) - @given(st_lists_or_text) - def test_returns_participant_label_as_list(self, label: list[str] | str): - result = _generate_filters(label)[0] + @given(st.text(), st_lists_or_text) + def test_returns_participant_label_as_dict(self, key: str, label: list[str] | str): + filters = _Postfilter() + filters.add_filter(key, label, None) if isinstance(label, str): - assert result == [label] + assert filters.inclusions == {key: [label]} else: - assert result == label + assert filters.inclusions == {key: label} + assert filters.exclusions == {} @given( + st.text(), st_lists_or_text, st.lists(st.text(valid_chars, min_size=1), min_size=1), st.text(valid_chars, min_size=1, max_size=3), ) def test_exclude_gives_regex_that_matches_anything_except_exclude( - self, excluded: list[str] | str, dummy_values: list[str], padding: str + self, key: str, excluded: list[str] | str, dummy_values: list[str], padding: str ): + filters = _Postfilter() # Make sure the dummy_values and padding we'll be testing against are different # from our test values for value in dummy_values: assume(value not in itx.always_iterable(excluded)) assume(padding not in itx.always_iterable(excluded)) - result = _generate_filters(exclude=excluded) - assert result[1] is True - assert isinstance(result[0], list) - assert len(result[0]) == 1 + filters.add_filter(key, None, excluded) + assert isinstance(filters.exclusions[key], list) + assert len(filters.exclusions[key]) == 1 # We match any value that isn't the exclude string for value in dummy_values: - assert re.match(result[0][0], value) + assert re.match(filters.exclusions[key][0], value) for exclude in itx.always_iterable(excluded): # We don't match the exclude string - assert re.match(result[0][0], exclude) is None + assert re.match(filters.exclusions[key][0], exclude) is None # Addition of random strings before and/or after lets the match occur again - assert re.match(result[0][0], padding + exclude) - assert re.match(result[0][0], exclude + padding) - assert re.match(result[0][0], padding + exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude) + assert re.match(filters.exclusions[key][0], exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude + padding) + assert filters.inclusions == {} class PathEntities(NamedTuple): @@ -592,7 +610,7 @@ def test_collects_all_paths_when_no_filters( test_path = self.generate_test_directory(entities, template, temp_dir) # Test without any filters - result = _parse_custom_path(test_path) + result = _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict({})) zip_lists = get_zip_list(entities, it.product(*entities.values())) assert BidsComponent( name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists @@ -612,7 +630,7 @@ def test_collects_only_filtered_entities( # Test with filters result_filtered = MultiSelectDict( - _parse_custom_path(test_path, regex_search=False, **filters) + _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict(filters)) ) zip_lists = MultiSelectDict( { @@ -628,6 +646,11 @@ def test_collects_only_filtered_entities( name="foo", path=get_bids_path(result_filtered), zip_lists=result_filtered ) + @debug( + path_entities=PathEntities( + entities={"A": ["A"]}, template=Path("A-{A}"), filters={"A": ["A"]} + ), + ) @settings(deadline=400, suppress_health_check=[HealthCheck.function_scoped_fixture]) @given(path_entities=path_entities()) 
def test_collect_all_but_filters_when_exclusion_filters_used( @@ -638,14 +661,13 @@ def test_collect_all_but_filters_when_exclusion_filters_used( entities, template, filters = path_entities test_path = self.generate_test_directory(entities, template, temp_dir) # Test with exclusion filters - exclude_filters = { - # We use _generate_filter to get our exclusion regex. This function was - # tested previously - key: _generate_filters(exclude=values)[0] - for key, values in filters.items() - } + exclude_filters = _Postfilter() + for key, values in filters.items(): + exclude_filters.add_filter(key, None, values) result_excluded = MultiSelectDict( - _parse_custom_path(test_path, regex_search=True, **exclude_filters) + _parse_custom_path( + test_path, _UnifiedFilter.from_filter_dict({}, exclude_filters) + ) ) entities_excluded = { @@ -971,26 +993,6 @@ def test_t1w_with_dict(): assert config["subj_wildcards"] == {"subject": "{subject}"} -@pytest.mark.skipif( - sys.version_info >= (3, 8), - reason=""" - Bug only surfaces on python 3.7 because higher python versions have access to the - latest pybids version - """, -) -def test_get_lists_from_bids_raises_pybids_error(): - """Test that we wrap a cryptic AttributeError from pybids with PybidsError. - - Pybids raises an AttributeError when a BIDSLayout.get is called with scope not - equal to 'all' on a layout that indexes a dataset without a - dataset_description.json. We wrap this error with something a bit less cryptic, so - this test ensures that that behaviour is still present. - """ - layout = BIDSLayout("snakebids/tests/data/bids_t1w", validate=False) - with pytest.raises(PybidsError): - next(_get_lists_from_bids(layout, {"t1": {"filters": {"scope": "raw"}}})) - - def test_get_lists_from_bids_raises_run_error(): bids_layout = None pybids_inputs: InputsConfig = { @@ -1000,7 +1002,9 @@ def test_get_lists_from_bids_raises_run_error(): } } with pytest.raises(RunError): - next(_get_lists_from_bids(bids_layout, pybids_inputs)) + next( + _get_lists_from_bids(bids_layout, pybids_inputs, postfilters=_Postfilter()) + ) def test_get_lists_from_bids(): @@ -1039,7 +1043,7 @@ def test_get_lists_from_bids(): pybids_inputs["t1"]["custom_path"] = wildcard_path_t1 pybids_inputs["t2"]["custom_path"] = wildcard_path_t2 - result = _get_lists_from_bids(layout, pybids_inputs) + result = _get_lists_from_bids(layout, pybids_inputs, postfilters=_Postfilter()) for bids_lists in result: if bids_lists.input_name == "t1": template = BidsComponent( @@ -1155,6 +1159,216 @@ def test_generate_inputs(dataset: BidsDataset, bids_fs: Path, fakefs_tmpdir: Pat assert reindexed.layout is not None +@st.composite +def dataset_with_subject(draw: st.DrawFn): + entities = draw(sb_st.bids_entity_lists(blacklist_entities=["subject"])) + entities += ["subject"] + return BidsDataset.from_iterable( + [ + draw( + sb_st.bids_components( + whitelist_entities=entities, + min_entities=len(entities), + max_entities=len(entities), + restrict_patterns=True, + unique=True, + ) + ) + ] + ) + + +class TestParticipantFiltering: + MODE = Literal["include", "exclude"] + + @pytest.fixture + def tmpdir(self, bids_fs: Path, fakefs_tmpdir: Path): + return fakefs_tmpdir + + def get_filter_params(self, mode: MODE, filters: list[str] | str): + class FiltParams(TypedDict, total=False): + participant_label: list[str] | str + exclude_participant_label: list[str] | str + + if mode == "include": + return FiltParams({"participant_label": filters}) + elif mode == "exclude": + return FiltParams({"exclude_participant_label": 
filters}) + raise ValueError(f"Invalid mode specification: {mode}") + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_filters_comps_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, participant_label=label) + assert set(itx.first(reindexed.values()).entities["subject"]) == set( + itx.always_iterable(label) + ) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_label_filters_comp_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=sb_st.datasets_one_comp(blacklist_entities=["subject"], unique=True), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_without_subject( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + reindexed = reindex_dataset( + root, rooted, **self.get_filter_params(mode, participant_filter) + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + create_dataset(Path("/"), rooted) + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ("include", "exclude")) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + 
data=st.data(),
+    )
+    @settings(
+        deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
+    )
+    def test_participant_label_doesnt_filter_comps_when_subject_has_filter_no_wcard(
+        self,
+        mode: MODE,
+        dataset: BidsDataset,
+        participant_filter: list[str] | str,
+        data: st.DataObject,
+        tmpdir: Path,
+    ):
+        root = tempfile.mkdtemp(dir=tmpdir)
+        rooted = BidsDataset.from_iterable(
+            attrs.evolve(comp, path=os.path.join(root, comp.path))
+            for comp in dataset.values()
+        )
+        subject = data.draw(
+            st.sampled_from(itx.first(rooted.values()).entities["subject"])
+        )
+        create_dataset(Path("/"), rooted)
+        config = create_snakebids_config(rooted)
+        for comp in config.values():
+            comp["filters"] = dict(comp.get("filters", {}))
+            comp["filters"]["subject"] = subject
+        reindexed = generate_inputs(
+            root,
+            create_snakebids_config(rooted),
+            **self.get_filter_params(mode, participant_filter),
+        )
+        assert reindexed == rooted
+
+    @given(
+        data=st.data(),
+        dataset=dataset_with_subject().filter(
+            lambda ds: set(itx.first(ds.values()).wildcards) != {"subject", "extension"}
+        ),
+    )
+    @settings(
+        deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]
+    )
+    def test_exclude_participant_does_not_make_all_other_filters_regex(
+        self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path
+    ):
+        root = tempfile.mkdtemp(dir=tmpdir)
+        rooted = BidsDataset.from_iterable(
+            attrs.evolve(comp, path=os.path.join(root, comp.path))
+            for comp in dataset.values()
+        )
+
+        # Create an extra set of decoy paths by modifying one of the existing
+        # components to prefix a set of entity values with "foo". If the exclusion
+        # filter were converted into a regex applied to every filter, all of these
+        # prefixed decoys would get picked up by pybids
+        ziplist = dict(itx.first(rooted.values()).zip_lists)
+        mut_entity = itx.first(
+            filter(lambda e: e not in {"subject", "extension"}, ziplist)
+        )
+        ziplist[mut_entity] = ["foo" + v for v in ziplist[mut_entity]]
+        for path in sb_expand(itx.first(rooted.values()).path, zip, **ziplist):
+            p = Path(path)
+            p.parent.mkdir(parents=True, exist_ok=True)
+            p.touch()
+
+        sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"])
+        label = data.draw(st.lists(sampler, unique=True) | sampler)
+        reindexed = reindex_dataset(root, rooted, exclude_participant_label=label)
+        reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"])
+        original_subjects = set(itx.first(rooted.values()).entities["subject"])
+        assert reindexed_subjects == original_subjects - set(itx.always_iterable(label))
+
+
 # The content of the dataset is irrelevant to this test, so one example suffices
 # but can't use extension because the custom path won't glob properly
 @settings(max_examples=1, suppress_health_check=[HealthCheck.function_scoped_fixture])
diff --git a/snakebids/types.py b/snakebids/types.py
index 6726da63..ac5c5f97 100644
--- a/snakebids/types.py
+++ b/snakebids/types.py
@@ -23,7 +23,7 @@ class InputConfig(TypedDict, total=False):
     """Configuration for a single bids component"""
 
-    filters: dict[str, str | bool | list[str | bool]]
+    filters: Mapping[str, str | bool | Sequence[str | bool]]
     """Filters to pass on to :class:`BIDSLayout.get() <bids.BIDSLayout>`
 
     Each key refers to the name of an entity. Values may take the following forms:
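
Taken together, the user-facing behaviour this diff targets can be sketched as follows. The dataset path, component configuration, and labels below are hypothetical, and the snippet assumes the diff is applied:

```python
# Hypothetical end-to-end usage: --participant-label /
# --exclude-participant-label are now applied after indexing, so components
# that already filter on `subject`, or that don't use it, are untouched.
from snakebids import generate_inputs

inputs = generate_inputs(
    bids_dir="data/bids",  # hypothetical dataset root
    pybids_inputs={
        "t1": {
            "filters": {"suffix": "T1w", "extension": ".nii.gz"},
            "wildcards": ["subject", "session"],
        },
    },
    # Excluded labels become post-filters rather than pybids regex filters,
    # so the other filters (e.g. suffix) are still matched literally
    exclude_participant_label=["01"],
)

assert "01" not in inputs["t1"].entities.get("subject", [])
```
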