diff --git a/docs/bids_app/config.md b/docs/bids_app/config.md index de387717..d095ef36 100644 --- a/docs/bids_app/config.md +++ b/docs/bids_app/config.md @@ -9,11 +9,11 @@ Config Variables ### `pybids_inputs` -A dictionary that describes each type of input you want to grab from an input BIDS dataset. Snakebids will parse your dataset with {func}`generate_inputs() `, converting each input type into a {class}`BidsComponent `. The value of each item should be a dictionary with keys ``filters`` and ``wildcards``. +A dictionary that describes each type of input you want to grab from an input BIDS dataset. Snakebids will parse your dataset with {func}`generate_inputs() `, converting each input type into a {class}`BidsComponent `. The value of each item should be a dictionary with keys `filters` and `wildcards`. -The value of ``filters`` should be a dictionary where each key corresponds to a BIDS entity, and the value specifies which values of that entity should be grabbed. The dictionary for each input is sent to the [PyBIDS' get() function ](#bids.layout.BIDSLayout). `filters` can be set according to a few different formats: +The value of `filters` should be a dictionary where each key corresponds to a BIDS entity, and the value specifies which values of that entity should be grabbed. The dictionary for each input is sent to the [PyBIDS' `get()` function ](#bids.layout.BIDSLayout). `filters` can be set according to a few different formats: -* [string](#str): specifies an exact value for the entity. In the following example: +* [`string`](#str): specifies an exact value for the entity. In the following example: ```yaml pybids_inputs: bold: @@ -29,22 +29,20 @@ The value of ``filters`` should be a dictionary where each key corresponds to a sub-xxx/.../func/ent1-xxx_ent2-xxx_..._bold.nii.gz ``` -* [boolean](#bool): constrains presence or absence of the entity without restricting its value. 
`False` requires that the entity be **absent**, while `True` requires that the entity be **present**, regardless of value. +* [`boolean`](#bool): constrains presence or absence of the entity without restricting its value. `False` requires that the entity be **absent**, while `True` requires that the entity be **present**, regardless of value. ```yaml pybids_inputs: derivs: filters: datatype: 'func' - desc: True # or true, or yes - acquisition: False # or false, or no + desc: True + acquisition: False ``` The above example maps all paths in the `func/` datatype folder that have a `_desc-` entity but do not have the `_acq-` entity. -In addition, the special filter `regex_search` can be set to `true`, which causes all other filters in the component to use regex matching instead of exact matching. +The value of `wildcards` should be a list of BIDS entities. Snakebids collects the values of any entities specified and saves them in the {attr}`entities ` and {attr}`~snakebids.BidsComponent.zip_lists` entries of the corresponding {class}`BidsComponent `. In other words, these are the entities to be preserved in output paths derived from the input being described. Placing an entity in `wildcards` does not require the entity be present. If an entity is not found, it will be left out of {attr}`entities `. To require the presence of an entity, place it under `filters` set to `true`. -The value of ``wildcards`` should be a list of BIDS entities. Snakebids collects the values of any entities specified and saves them in the {attr}`entities ` and {attr}`~snakebids.BidsComponent.zip_lists` entries of the corresponding {class}`BidsComponent `. In other words, these are the entities to be preserved in output paths derived from the input being described. Placing an entity in `wildcards` does not require the entity be present. If an entity is not found, it will be left out of {attr}`entities `. To require the presence of an entity, place it under `filters` set to `true`. 
- -In the following (YAML-formatted) example, the ``bold`` input type is specified. BIDS files with the datatype ``func``, suffix ``bold``, and extension ``.nii.gz`` will be grabbed, and the ``subject``, ``session``, ``acquisition``, ``task``, and ``run`` entities of those files will be left as wildcards. The `task` entity must be present, but there must not be any `desc`. +In the following (YAML-formatted) example, the `bold` input type is specified. BIDS files with the datatype `func`, suffix `bold`, and extension `.nii.gz` will be grabbed, and the `subject`, `session`, `acquisition`, `task`, and `run` entities of those files will be left as wildcards. The `task` entity must be present, but there must not be any `desc`. ```yaml pybids_inputs: diff --git a/snakebids/core/datasets.py b/snakebids/core/datasets.py index 5acf232d..e6f389e0 100644 --- a/snakebids/core/datasets.py +++ b/snakebids/core/datasets.py @@ -458,6 +458,8 @@ def filter( if not isinstance(regex_search, bool): msg = "regex_search must be a boolean" raise TypeError(msg) + if not filters: + return self return attr.evolve( self, zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search), diff --git a/snakebids/core/filtering.py b/snakebids/core/filtering.py index b8517bfc..31cdbeab 100644 --- a/snakebids/core/filtering.py +++ b/snakebids/core/filtering.py @@ -25,7 +25,7 @@ def filter_list( def filter_list( zip_list: ZipListLike, filters: Mapping[str, Iterable[str] | str], - return_indices_only: Literal[True] = ..., + return_indices_only: Literal[True], regex_search: bool = ..., ) -> list[int]: ... 
diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index 03ba996a..7ce51ab7 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -1,6 +1,7 @@ """Utilities for converting Snakemake apps to BIDS apps.""" from __future__ import annotations +import functools as ft import json import logging import os @@ -15,7 +16,7 @@ from bids import BIDSLayout, BIDSLayoutIndexer from bids.layout import BIDSFile, Query from snakemake.script import Snakemake -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias from snakebids.core.datasets import BidsComponent, BidsDataset, BidsDatasetDict from snakebids.core.filtering import filter_list @@ -25,7 +26,7 @@ PybidsError, RunError, ) -from snakebids.types import InputsConfig, ZipList +from snakebids.types import InputConfig, InputsConfig, ZipList from snakebids.utils.snakemake_io import glob_wildcards from snakebids.utils.utils import ( DEPRECATION_FLAG, @@ -36,6 +37,9 @@ _logger = logging.getLogger(__name__) +FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" +CompiledFilter: TypeAlias = "Mapping[str, str | Query | Sequence[str | Query]]" + @overload def generate_inputs( @@ -255,9 +259,8 @@ def generate_inputs( ), }) """ - subject_filter, regex_search = _generate_filters( - participant_label, exclude_participant_label - ) + postfilters = _Postfilter() + postfilters.add_filter("subject", participant_label, exclude_participant_label) pybidsdb_dir, pybidsdb_reset = _normalize_database_args( pybidsdb_dir, pybidsdb_reset, pybids_database_dir, pybids_reset_database @@ -277,13 +280,11 @@ def generate_inputs( else None ) - filters = {"subject": subject_filter} if subject_filter else {} bids_inputs = _get_lists_from_bids( bids_layout=layout, pybids_inputs=pybids_inputs, limit_to=limit_to, - regex_search=regex_search, - **(filters), + postfilters=postfilters, ) if use_bids_inputs is True: @@ -465,65 +466,195 @@ def 
write_derivative_json(snakemake: Snakemake, **kwargs: dict[str, Any]) -> Non json.dump(sidecar, outfile, indent=4) -def _generate_filters( - include: Iterable[str] | str | None = None, - exclude: Iterable[str] | str | None = None, -) -> tuple[list[str], bool]: - """Generate Pybids filter based on inclusion or exclusion criteria. +class _Postfilter: + """Filters to apply after indexing, typically derived from the CLI. - Converts either a list of values to include or exclude in a list of Pybids - compatible filters. Unlike inclusion values, exclusion requires regex filtering. The - necessity for regex will be indicated by the boolean value of the second returned - item: True if regex is needed, False otherwise. Raises an exception if both include - and exclude are stipulated + Currently used for supporting ``--[exclude-]participant-label`` + """ - Parameters - ---------- - include : list of str or str, optional - Values to include, values not found in this list will be excluded, by default - None - exclude : list of str or str, optional - Values to exclude, only values not found in this list will be included, by - default None + def __init__(self): + self.inclusions: dict[str, Sequence[str] | str] = {} + self.exclusions: dict[str, Sequence[str] | str] = {} - Returns - ------- - list of str, bool - Two values: the first, a list of pybids compatible filters; the second, a - boolean indicating whether regex_search must be enabled in pybids + def add_filter( + self, + key: str, + inclusions: Iterable[str] | str | None, + exclusions: Iterable[str] | str | None, + ): + """Add entity filter based on inclusion or exclusion criteria. + + Converts a list of values to include or exclude into Pybids compatible filters. + Exclusion filters are appropriately formatted as regex. 
Raises an exception if + both include and exclude are stipulated + + _Postfilter is modified in-place + + Parameters + ---------- + key + Name of entity to be filtered + inclusions + Values to include, values not found in this list will be excluded, by + default None + exclusions + Values to exclude, only values not found in this list will be included, by + default None + + Raises + ------ + ValueError + Raised if both include and exclude values are stipulated. + """ + if inclusions is not None and exclusions is not None: + msg = ( + "Cannot define both participant_label and exclude_participant_label at " + "the same time" + ) + raise ValueError(msg) + if inclusions is not None: + self.inclusions[key] = list(itx.always_iterable(inclusions)) + if exclusions is not None: + self.exclusions[key] = self._format_exclusions(exclusions) - Raises - ------ - ValueError Raised of both include and exclude values are stipulated. + def _format_exclusions(self, exclusions: Iterable[str] | str): + # if multiple items to exclude, combine with item1|item2|... + exclude_string = "|".join( + re.escape(label) for label in itx.always_iterable(exclusions) + ) + # regex to exclude subjects + return [f"^((?!({exclude_string})$).*)$"] + + +def _compile_filters(filters: FilterType) -> CompiledFilter: + return { + key: [ + Query.ANY if f is True else Query.NONE if f is False else f + for f in itx.always_iterable(filts) + ] + for key, filts in filters.items() + } + + +@attrs.define(slots=False) +class _UnifiedFilter: + """Manages component level and post filters.""" + + component: InputConfig + """The Component configuration defining the filters""" + + postfilters: _Postfilter + """Filters to be applied after collecting and parsing the data + + Currently only used to implement --[exclude-]participant-label, but in the future, + may implement other such CLI args. Unlike configuration-defined filters, these + filters apply after the dataset is indexed and queried. 
Thus, if a filter is set + to an empty list, a complete, albeit empty, component may be found. This is akin to + calling ``BidsComponent.filter`` after running ``generate_inputs``. + + For performance purposes, non-empty post-filters are applied via ``pybids.get()`` """ - if include is not None and exclude is not None: - msg = ( - "Cannot define both participant_label and " - "exclude_participant_label at the same time" + + @classmethod + def from_filter_dict( + cls, + filters: Mapping[str, str | bool | Sequence[str | bool]], + postfilter: _Postfilter | None = None, + ) -> Self: + """Patch together a UnifiedFilter based on a basic filter dict. + + Intended primarily for use in testing + """ + wildcards: list[str] = [] + if postfilter is not None: + wildcards.extend(postfilter.inclusions) + wildcards.extend(postfilter.exclusions) + return cls( + {"filters": filters, "wildcards": wildcards}, postfilter or _Postfilter() ) - raise ValueError(msg) - # add participant_label or exclude_participant_label to search terms (if - # defined) - # we make the item key in search_terms a list so we can have both - # include and exclude defined - if include is not None: - return [*itx.always_iterable(include)], False + def _has_empty_list(self, items: Iterable[Any]): + """Check if any of the lists within iterable are empty.""" + return any(itx.ilen(itx.always_iterable(item)) == 0 for item in items) + + def _has_overlap(self, key: str): + """Check if filter key is a wildcard and not already a prefilter.""" + return key not in self.prefilters and key in self.component.get("wildcards", []) + + @ft.cached_property + def prefilters(self) -> FilterType: + """Filters defined in the component configuration and applied via pybids. + + Unlike postfilters, a prefilter set to an empty list will result in no valid + paths found, resulting in a blank (missing) component. + """ + filters = dict(self.component.get("filters", {})) + # Silently remove "regex_search". 
This value has been blocked by a bug + # since version 0.6, and even before, never fully worked properly (e.g. would + # break if combined with --exclude-participant-label) + if "regex_search" in filters: + del filters["regex_search"] + return filters + + @ft.cached_property + def filters(self) -> CompiledFilter: + """The combination of pre- and post-filters to be applied to pybids indexing. + + Includes all pre-filters, and all inclusion post-filters. Empty post-filters + are replaced with Query.ANY. This allows valid paths to be found and processed + later. Post-filters are not applied when an equivalent prefilter is present + """ + result = dict(_compile_filters(self.prefilters)) + postfilters = self.postfilters.inclusions + for key in self.postfilters.inclusions: + if self._has_overlap(key): + # if empty list filter, ensure the entity filtered is present + result[key] = ( + postfilters[key] + if itx.ilen(itx.always_iterable(postfilters[key])) + else Query.ANY + ) + return result - if exclude is not None: - # if multiple items to exclude, combine with with item1|item2|... 
- exclude_string = "|".join( - re.escape(label) for label in itx.always_iterable(exclude) + @property + def post_exclusions(self) -> dict[str, Sequence[str] | str]: + """Dictionary of all post-exclusion filters.""" + return { + key: val + for key, val in self.postfilters.exclusions.items() + if self._has_overlap(key) + } + + @property + def without_bools(self) -> Mapping[str, str | Sequence[str]]: + """Check and typeguard to ensure filters do not contain booleans.""" + for key, val in self.filters.items(): + if any(isinstance(v, Query) for v in itx.always_iterable(val)): + msg = ( + "Boolean filters in items with custom paths are not supported; in " + f"component='{key}'" + ) + raise ValueError(msg) + return cast("Mapping[str, str | Sequence[str]]", self.filters) + + @property + def has_empty_prefilter(self) -> bool: + """Returns True if even one prefilter is empty.""" + return self._has_empty_list(self.prefilters.values()) + + @property + def has_empty_postfilter(self) -> bool: + """Returns True if even one postfilter is empty.""" + return self._has_empty_list( + filt + for name, filt in self.postfilters.inclusions.items() + if self._has_overlap(name) ) - # regex to exclude subjects - return [f"^((?!({exclude_string})$).*)$"], True - return [], False def _parse_custom_path( input_path: Path | str, - regex_search: bool = False, - **filters: Sequence[str | bool] | str | bool, + filters: _UnifiedFilter, ) -> ZipList: """Glob wildcards from a custom path and apply filters. 
@@ -558,16 +689,15 @@ def _parse_custom_path( return wildcards # Return the output values, running filtering on the zip_lists - if any( - isinstance(v, bool) for f in filters.values() for v in itx.always_iterable(f) - ): - msg = "boolean filters are not currently supported in custom path filtering" - raise TypeError(msg) - return filter_list( + result = filter_list( wildcards, - cast("Mapping[str, str | Sequence[str]]", filters), - regex_search=regex_search, + filters.without_bools, + regex_search=False, ) + if not filters.post_exclusions: + return result + + return filter_list(result, filters.post_exclusions, regex_search=True) def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, str]]: @@ -630,46 +760,24 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, return "".join(new_path), wildcard_values -@attrs.define -class _GetListsFromBidsSteps: - input_name: str - - FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" - - def _get_invalid_filters(self, filters: FilterType): - for key, filts in filters.items(): - try: - if True in filts or False in filts: # type: ignore - yield key, filts - except TypeError: - pass - - def prepare_bids_filters( - self, filters: Mapping[str, str | bool | Sequence[str | bool]] - ) -> dict[str, str | Query | list[str | Query]]: - return { - key: [ - Query.ANY if f is True else Query.NONE if f is False else f - for f in itx.always_iterable(filts) - ] - for key, filts in filters.items() - } - - def get_matching_files( - self, - bids_layout: BIDSLayout, - regex_search: bool, - filters: Mapping[str, str | Query | Sequence[str | Query]], - ) -> Iterable[BIDSFile]: - try: - return bids_layout.get(regex_search=regex_search, **filters) - except AttributeError as err: - msg = ( - "Pybids has encountered a problem that Snakebids cannot handle. This " - "may indicate a missing or invalid dataset_description.json for this " - "dataset." 
- ) - raise PybidsError(msg) from err +def _get_matching_files( + bids_layout: BIDSLayout, + filters: _UnifiedFilter, +) -> Iterable[BIDSFile]: + if filters.has_empty_prefilter: + return [] + try: + return bids_layout.get( + regex_search=False, + **filters.filters, + ) + except AttributeError as err: + msg = ( + "Pybids has encountered a problem that Snakebids cannot handle. This " + "may indicate a missing or invalid dataset_description.json for this " + "dataset." + ) + raise PybidsError(msg) from err def _get_lists_from_bids( @@ -677,8 +785,7 @@ def _get_lists_from_bids( pybids_inputs: InputsConfig, *, limit_to: Iterable[str] | None = None, - regex_search: bool = False, - **filters: str | Sequence[str], + postfilters: _Postfilter, ) -> Generator[BidsComponent, None, None]: """Grabs files using pybids and creates snakemake-friendly lists. @@ -704,10 +811,11 @@ def _get_lists_from_bids( One BidsComponent is yielded for each modality described by ``pybids_inputs``. """ for input_name in limit_to or list(pybids_inputs): - steps = _GetListsFromBidsSteps(input_name) _logger.debug("Grabbing inputs for %s...", input_name) component = pybids_inputs[input_name] + filters = _UnifiedFilter(component, postfilters or {}) + if "custom_path" in component: # a custom path was specified for this input, skip pybids: # get input_wildcards by parsing path for {} entries (using a set @@ -716,12 +824,7 @@ def _get_lists_from_bids( # to deal with multiple wildcards path = component["custom_path"] - zip_lists = _parse_custom_path( - path, - regex_search=regex_search, - **pybids_inputs[input_name].get("filters", {}), - **filters, - ) + zip_lists = _parse_custom_path(path, filters=filters) yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) continue @@ -734,10 +837,7 @@ def _get_lists_from_bids( zip_lists: dict[str, list[str]] = defaultdict(list) paths: set[str] = set() - pybids_filters = steps.prepare_bids_filters(component.get("filters", {})) - matching_files = 
steps.get_matching_files( - bids_layout, regex_search, {**pybids_filters, **filters} - ) + matching_files = _get_matching_files(bids_layout, filters) for img in matching_files: wildcards: list[str] = [ @@ -805,7 +905,15 @@ def _get_lists_from_bids( ) raise ConfigError(msg) from err - yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) + if filters.has_empty_postfilter: + yield BidsComponent( + name=input_name, path=path, zip_lists={key: [] for key in zip_lists} + ) + continue + + yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists).filter( + regex_search=True, **filters.post_exclusions + ) def get_wildcard_constraints(image_types: InputsConfig) -> dict[str, str]: diff --git a/snakebids/tests/helpers.py b/snakebids/tests/helpers.py index 294142fb..c7996464 100644 --- a/snakebids/tests/helpers.py +++ b/snakebids/tests/helpers.py @@ -209,7 +209,11 @@ def create_snakebids_config(dataset: BidsDataset) -> InputsConfig: def reindex_dataset( - root: str, dataset: BidsDataset, use_custom_paths: bool = False + root: str, + dataset: BidsDataset, + use_custom_paths: bool = False, + participant_label: str | Sequence[str] | None = None, + exclude_participant_label: str | Sequence[str] | None = None, ) -> BidsDataset: """Create BidsDataset on the filesystem and reindex @@ -217,10 +221,20 @@ def reindex_dataset( """ create_dataset(Path("/"), dataset) config = create_snakebids_config(dataset) + if participant_label is not None or exclude_participant_label is not None: + for comp in config.values(): + if "subject" in comp.get("filters", {}): + del comp["filters"]["subject"] # type: ignore + if use_custom_paths: for comp in config: config[comp]["custom_path"] = dataset[comp].path - return generate_inputs(root, config) + return generate_inputs( + root, + config, + participant_label=participant_label, + exclude_participant_label=exclude_participant_label, + ) def allow_function_scoped(func: _T, /) -> _T: diff --git 
a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index 2008dd65..2fd31375 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -13,7 +13,7 @@ import warnings from collections import defaultdict from pathlib import Path -from typing import Any, Callable, Iterable, NamedTuple, TypeVar, cast +from typing import Iterable, Literal, NamedTuple, TypedDict, TypeVar, cast import attrs import more_itertools as itx @@ -23,25 +23,29 @@ from hypothesis import strategies as st from pyfakefs.fake_filesystem import FakeFilesystem from pytest_mock import MockerFixture +from snakemake.io import expand as sb_expand from snakebids.core.datasets import BidsComponent, BidsDataset from snakebids.core.input_generation import ( _all_custom_paths, _gen_bids_layout, - _generate_filters, _get_lists_from_bids, _normalize_database_args, _parse_bids_path, _parse_custom_path, + _Postfilter, + _UnifiedFilter, generate_inputs, ) -from snakebids.exceptions import ConfigError, PybidsError, RunError +from snakebids.exceptions import ConfigError, RunError from snakebids.paths.presets import bids from snakebids.tests import strategies as sb_st from snakebids.tests.helpers import ( + Benchmark, BidsListCompare, allow_function_scoped, create_dataset, + create_snakebids_config, example_if, get_bids_path, get_zip_list, @@ -535,7 +539,7 @@ def test_missing_wildcards(self, tmpdir: Path): assert config.subj_wildcards == {"subject": "{subject}"} -class TestGenerateFilter: +class TestPostfilter: valid_chars = st.characters(blacklist_characters=["\n"]) st_lists_or_text = st.lists(st.text(valid_chars)) | st.text(valid_chars) @@ -543,52 +547,56 @@ class TestGenerateFilter: def test_throws_error_if_labels_and_excludes_are_given( self, args: tuple[list[str] | str, list[str] | str] ): + filters = _Postfilter() with pytest.raises( ValueError, - match="Cannot define both participant_label and " - "exclude_participant_label at the same 
time", + match="Cannot define both participant_label and exclude_participant_label ", ): - _generate_filters(*args) + filters.add_filter("foo", *args) - @given(st_lists_or_text) - def test_returns_participant_label_as_list(self, label: list[str] | str): - result = _generate_filters(label)[0] + @given(st.text(), st_lists_or_text) + def test_returns_participant_label_as_dict(self, key: str, label: list[str] | str): + filters = _Postfilter() + filters.add_filter(key, label, None) if isinstance(label, str): - assert result == [label] + assert filters.inclusions == {key: [label]} else: - assert result == label + assert filters.inclusions == {key: label} + assert filters.exclusions == {} @given( + st.text(), st_lists_or_text, st.lists(st.text(valid_chars, min_size=1), min_size=1), st.text(valid_chars, min_size=1, max_size=3), ) def test_exclude_gives_regex_that_matches_anything_except_exclude( - self, excluded: list[str] | str, dummy_values: list[str], padding: str + self, key: str, excluded: list[str] | str, dummy_values: list[str], padding: str ): + filters = _Postfilter() # Make sure the dummy_values and padding we'll be testing against are different # from our test values for value in dummy_values: assume(value not in itx.always_iterable(excluded)) assume(padding not in itx.always_iterable(excluded)) - result = _generate_filters(exclude=excluded) - assert result[1] is True - assert isinstance(result[0], list) - assert len(result[0]) == 1 + filters.add_filter(key, None, excluded) + assert isinstance(filters.exclusions[key], list) + assert len(filters.exclusions[key]) == 1 # We match any value that isn't the exclude string for value in dummy_values: - assert re.match(result[0][0], value) + assert re.match(filters.exclusions[key][0], value) for exclude in itx.always_iterable(excluded): # We don't match the exclude string - assert re.match(result[0][0], exclude) is None + assert re.match(filters.exclusions[key][0], exclude) is None # Addition of random strings before 
and/or after lets the match occur again - assert re.match(result[0][0], padding + exclude) - assert re.match(result[0][0], exclude + padding) - assert re.match(result[0][0], padding + exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude) + assert re.match(filters.exclusions[key][0], exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude + padding) + assert filters.inclusions == {} class PathEntities(NamedTuple): @@ -685,13 +693,11 @@ def generate_test_directory( T = TypeVar("T") - def test_benchmark_test_custom_paths( - self, benchmark: Callable[[Callable[..., Any], Any], Any], tmp_path: Path - ): + def test_benchmark_test_custom_paths(self, benchmark: Benchmark, tmp_path: Path): entities = {"A": ["A", "B", "C"], "B": ["1", "2", "3"]} template = Path("{A}/A-{A}_B-{B}") test_path = self.generate_test_directory(entities, template, tmp_path) - benchmark(_parse_custom_path, test_path) + benchmark(_parse_custom_path, test_path, _UnifiedFilter.from_filter_dict({})) @allow_function_scoped @given(path_entities=path_entities()) @@ -704,7 +710,7 @@ def test_collects_all_paths_when_no_filters( test_path = self.generate_test_directory(entities, template, temp_dir) # Test without any filters - result = _parse_custom_path(test_path) + result = _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict({})) zip_lists = get_zip_list(entities, it.product(*entities.values())) assert BidsComponent( name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists @@ -724,7 +730,7 @@ def test_collects_only_filtered_entities( # Test with filters result_filtered = MultiSelectDict( - _parse_custom_path(test_path, regex_search=False, **filters) + _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict(filters)) ) zip_lists = MultiSelectDict( { @@ -750,14 +756,13 @@ def test_collect_all_but_filters_when_exclusion_filters_used( entities, template, filters = path_entities test_path = self.generate_test_directory(entities, 
template, temp_dir) # Test with exclusion filters - exclude_filters = { - # We use _generate_filter to get our exclusion regex. This function was - # tested previously - key: _generate_filters(exclude=values)[0] - for key, values in filters.items() - } + exclude_filters = _Postfilter() + for key, values in filters.items(): + exclude_filters.add_filter(key, None, values) result_excluded = MultiSelectDict( - _parse_custom_path(test_path, regex_search=True, **exclude_filters) + _parse_custom_path( + test_path, _UnifiedFilter.from_filter_dict({}, exclude_filters) + ) ) entities_excluded = { @@ -1083,26 +1088,6 @@ def test_t1w_with_dict(): assert config["subj_wildcards"] == {"subject": "{subject}"} -@pytest.mark.skipif( - sys.version_info >= (3, 8), - reason=""" - Bug only surfaces on python 3.7 because higher python versions have access to the - latest pybids version - """, -) -def test_get_lists_from_bids_raises_pybids_error(): - """Test that we wrap a cryptic AttributeError from pybids with PybidsError. - - Pybids raises an AttributeError when a BIDSLayout.get is called with scope not - equal to 'all' on a layout that indexes a dataset without a - dataset_description.json. We wrap this error with something a bit less cryptic, so - this test ensures that that behaviour is still present. 
- """ - layout = BIDSLayout("snakebids/tests/data/bids_t1w", validate=False) - with pytest.raises(PybidsError): - next(_get_lists_from_bids(layout, {"t1": {"filters": {"scope": "raw"}}})) - - def test_get_lists_from_bids_raises_run_error(): bids_layout = None pybids_inputs: InputsConfig = { @@ -1112,7 +1097,9 @@ def test_get_lists_from_bids_raises_run_error(): } } with pytest.raises(RunError): - next(_get_lists_from_bids(bids_layout, pybids_inputs)) + next( + _get_lists_from_bids(bids_layout, pybids_inputs, postfilters=_Postfilter()) + ) def test_get_lists_from_bids(): @@ -1151,7 +1138,7 @@ def test_get_lists_from_bids(): pybids_inputs["t1"]["custom_path"] = wildcard_path_t1 pybids_inputs["t2"]["custom_path"] = wildcard_path_t2 - result = _get_lists_from_bids(layout, pybids_inputs) + result = _get_lists_from_bids(layout, pybids_inputs, postfilters=_Postfilter()) for bids_lists in result: if bids_lists.input_name == "t1": template = BidsComponent( @@ -1267,6 +1254,217 @@ def test_generate_inputs(dataset: BidsDataset, bids_fs: Path, fakefs_tmpdir: Pat assert reindexed.layout is not None +@st.composite +def dataset_with_subject(draw: st.DrawFn): + entities = draw(sb_st.bids_entity_lists(blacklist_entities=["subject"])) + entities += ["subject"] + return BidsDataset.from_iterable( + [ + draw( + sb_st.bids_components( + whitelist_entities=entities, + min_entities=len(entities), + max_entities=len(entities), + restrict_patterns=True, + unique=True, + ) + ) + ] + ) + + +class TestParticipantFiltering: + MODE = Literal["include", "exclude"] + + @pytest.fixture + def tmpdir(self, bids_fs: Path, fakefs_tmpdir: Path): + return fakefs_tmpdir + + def get_filter_params(self, mode: MODE, filters: list[str] | str): + class FiltParams(TypedDict, total=False): + participant_label: list[str] | str + exclude_participant_label: list[str] | str + + if mode == "include": + return FiltParams({"participant_label": filters}) + if mode == "exclude": + return 
FiltParams({"exclude_participant_label": filters}) + msg = f"Invalid mode specification: {mode}" + raise ValueError(msg) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_filters_comps_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, participant_label=label) + assert set(itx.first(reindexed.values()).entities["subject"]) == set( + itx.always_iterable(label) + ) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_label_filters_comp_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=sb_st.datasets_one_comp(blacklist_entities=["subject"], unique=True), + 
participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_without_subject( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + reindexed = reindex_dataset( + root, rooted, **self.get_filter_params(mode, participant_filter) + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + create_dataset(Path("/"), rooted) + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + data=st.data(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter_no_wcard( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + data: st.DataObject, + tmpdir: Path, + ): + root = 
tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + subject = data.draw( + st.sampled_from(itx.first(rooted.values()).entities["subject"]) + ) + create_dataset(Path("/"), rooted) + config = create_snakebids_config(rooted) + for comp in config.values(): + comp["filters"] = dict(comp.get("filters", {})) + comp["filters"]["subject"] = subject + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @given( + data=st.data(), + dataset=dataset_with_subject().filter( + lambda ds: set(itx.first(ds.values()).wildcards) != {"subject", "extension"} + ), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_does_not_make_all_other_filters_regex( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + + # Create an extra set of paths by modifying one of the existing components to put + # foo after a set of entity values.
If that filter gets changed to a regex, all + # of the suffixed decoys will get picked up by pybids + ziplist = dict(itx.first(rooted.values()).zip_lists) + mut_entity = itx.first( + filter(lambda e: e not in {"subject", "extension"}, ziplist) + ) + ziplist[mut_entity] = ["foo" + v for v in ziplist[mut_entity]] + for path in sb_expand(itx.first(rooted.values()).path, zip, **ziplist): + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + p.touch() + + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + # The content of the dataset is irrelevant to this test, so one example suffices # but can't use extension because the custom path won't glob properly @settings(max_examples=1, suppress_health_check=[HealthCheck.function_scoped_fixture]) diff --git a/snakebids/types.py b/snakebids/types.py index 72a0a7be..342f4fe1 100644 --- a/snakebids/types.py +++ b/snakebids/types.py @@ -23,7 +23,7 @@ class InputConfig(TypedDict, total=False): """Configuration for a single bids component.""" - filters: dict[str, str | bool | list[str | bool]] + filters: Mapping[str, str | bool | Sequence[str | bool]] """Filters to pass on to :class:`BIDSLayout.get() ` Each key refers to the name of an entity. Values may take the following forms: