From 80ed72d64feff20e073c60164912b46944c23b7c Mon Sep 17 00:00:00 2001 From: Peter Van Dyken Date: Tue, 17 Oct 2023 18:11:44 -0400 Subject: [PATCH] Overhaul global filtering of participants --participant-filter and --exclude-participant filter have had a couple of bugs over the past few versions: 1. Both terms would apply to every single component, even components without the subject entity. These components would have all of their entries filtered out 2. Using --exclude-participant-filter would turn on regex matching mode for every single filter. This changed the meaning of the filters (e.g. allowing partial matches), leading to workflow disruption The overhaul fixes both bugs, while improving the organization of the input generation code. Additionally, the documentation states the magic filter `regex_search: True` can be used to enable regex searching for a block of filters. This hasn't worked for the past few versions. Rather than fix it, the behaviour has been silently disabled in preparation for an overhaul of the regex filtering api. Mention of this feature has been removed from documentation Resolves #303 Resolves #216 --- docs/bids_app/config.md | 18 +- snakebids/core/datasets.py | 2 + snakebids/core/filtering.py | 2 +- snakebids/core/input_generation.py | 340 ++++++++++++++++-------- snakebids/tests/helpers.py | 18 +- snakebids/tests/test_generate_inputs.py | 312 ++++++++++++++++++---- snakebids/types.py | 2 +- 7 files changed, 507 insertions(+), 187 deletions(-) diff --git a/docs/bids_app/config.md b/docs/bids_app/config.md index de387717..d095ef36 100644 --- a/docs/bids_app/config.md +++ b/docs/bids_app/config.md @@ -9,11 +9,11 @@ Config Variables ### `pybids_inputs` -A dictionary that describes each type of input you want to grab from an input BIDS dataset. Snakebids will parse your dataset with {func}`generate_inputs() `, converting each input type into a {class}`BidsComponent `. The value of each item should be a dictionary with keys ``filters`` and ``wildcards``. +A dictionary that describes each type of input you want to grab from an input BIDS dataset. Snakebids will parse your dataset with {func}`generate_inputs() `, converting each input type into a {class}`BidsComponent `. The value of each item should be a dictionary with keys `filters` and `wildcards`. -The value of ``filters`` should be a dictionary where each key corresponds to a BIDS entity, and the value specifies which values of that entity should be grabbed. The dictionary for each input is sent to the [PyBIDS' get() function ](#bids.layout.BIDSLayout). `filters` can be set according to a few different formats: +The value of `filters` should be a dictionary where each key corresponds to a BIDS entity, and the value specifies which values of that entity should be grabbed. The dictionary for each input is sent to the [PyBIDS' `get()` function ](#bids.layout.BIDSLayout). `filters` can be set according to a few different formats: -* [string](#str): specifies an exact value for the entity. In the following example: +* [`string`](#str): specifies an exact value for the entity. In the following example: ```yaml pybids_inputs: bold: @@ -29,22 +29,20 @@ The value of ``filters`` should be a dictionary where each key corresponds to a sub-xxx/.../func/ent1-xxx_ent2-xxx_..._bold.nii.gz ``` -* [boolean](#bool): constrains presence or absence of the entity without restricting its value. `False` requires that the entity be **absent**, while `True` requires that the entity be **present**, regardless of value. 
+* [`boolean`](#bool): constrains presence or absence of the entity without restricting its value. `False` requires that the entity be **absent**, while `True` requires that the entity be **present**, regardless of value. ```yaml pybids_inputs: derivs: filters: datatype: 'func' - desc: True # or true, or yes - acquisition: False # or false, or no + desc: True + acquisition: False ``` The above example maps all paths in the `func/` datatype folder that have a `_desc-` entity but do not have the `_acq-` entity. -In addition, the special filter `regex_search` can be set to `true`, which causes all other filters in the component to use regex matching instead of exact matching. +The value of `wildcards` should be a list of BIDS entities. Snakebids collects the values of any entities specified and saves them in the {attr}`entities ` and {attr}`~snakebids.BidsComponent.zip_lists` entries of the corresponding {class}`BidsComponent `. In other words, these are the entities to be preserved in output paths derived from the input being described. Placing an entity in `wildcards` does not require the entity be present. If an entity is not found, it will be left out of {attr}`entities `. To require the presence of an entity, place it under `filters` set to `true`. -The value of ``wildcards`` should be a list of BIDS entities. Snakebids collects the values of any entities specified and saves them in the {attr}`entities ` and {attr}`~snakebids.BidsComponent.zip_lists` entries of the corresponding {class}`BidsComponent `. In other words, these are the entities to be preserved in output paths derived from the input being described. Placing an entity in `wildcards` does not require the entity be present. If an entity is not found, it will be left out of {attr}`entities `. To require the presence of an entity, place it under `filters` set to `true`. - -In the following (YAML-formatted) example, the ``bold`` input type is specified. BIDS files with the datatype ``func``, suffix ``bold``, and extension ``.nii.gz`` will be grabbed, and the ``subject``, ``session``, ``acquisition``, ``task``, and ``run`` entities of those files will be left as wildcards. The `task` entity must be present, but there must not be any `desc`. +In the following (YAML-formatted) example, the `bold` input type is specified. BIDS files with the datatype `func`, suffix `bold`, and extension `.nii.gz` will be grabbed, and the `subject`, `session`, `acquisition`, `task`, and `run` entities of those files will be left as wildcards. The `task` entity must be present, but there must not be any `desc`. ```yaml pybids_inputs: diff --git a/snakebids/core/datasets.py b/snakebids/core/datasets.py index 5acf232d..e6f389e0 100644 --- a/snakebids/core/datasets.py +++ b/snakebids/core/datasets.py @@ -458,6 +458,8 @@ def filter( if not isinstance(regex_search, bool): msg = "regex_search must be a boolean" raise TypeError(msg) + if not filters: + return self return attr.evolve( self, zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search), diff --git a/snakebids/core/filtering.py b/snakebids/core/filtering.py index b8517bfc..31cdbeab 100644 --- a/snakebids/core/filtering.py +++ b/snakebids/core/filtering.py @@ -25,7 +25,7 @@ def filter_list( def filter_list( zip_list: ZipListLike, filters: Mapping[str, Iterable[str] | str], - return_indices_only: Literal[True] = ..., + return_indices_only: Literal[True], regex_search: bool = ..., ) -> list[int]: ... 
diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index 03ba996a..7ce51ab7 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -1,6 +1,7 @@ """Utilities for converting Snakemake apps to BIDS apps.""" from __future__ import annotations +import functools as ft import json import logging import os @@ -15,7 +16,7 @@ from bids import BIDSLayout, BIDSLayoutIndexer from bids.layout import BIDSFile, Query from snakemake.script import Snakemake -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias from snakebids.core.datasets import BidsComponent, BidsDataset, BidsDatasetDict from snakebids.core.filtering import filter_list @@ -25,7 +26,7 @@ PybidsError, RunError, ) -from snakebids.types import InputsConfig, ZipList +from snakebids.types import InputConfig, InputsConfig, ZipList from snakebids.utils.snakemake_io import glob_wildcards from snakebids.utils.utils import ( DEPRECATION_FLAG, @@ -36,6 +37,9 @@ _logger = logging.getLogger(__name__) +FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" +CompiledFilter: TypeAlias = "Mapping[str, str | Query | Sequence[str | Query]]" + @overload def generate_inputs( @@ -255,9 +259,8 @@ def generate_inputs( ), }) """ - subject_filter, regex_search = _generate_filters( - participant_label, exclude_participant_label - ) + postfilters = _Postfilter() + postfilters.add_filter("subject", participant_label, exclude_participant_label) pybidsdb_dir, pybidsdb_reset = _normalize_database_args( pybidsdb_dir, pybidsdb_reset, pybids_database_dir, pybids_reset_database @@ -277,13 +280,11 @@ def generate_inputs( else None ) - filters = {"subject": subject_filter} if subject_filter else {} bids_inputs = _get_lists_from_bids( bids_layout=layout, pybids_inputs=pybids_inputs, limit_to=limit_to, - regex_search=regex_search, - **(filters), + postfilters=postfilters, ) if use_bids_inputs is True: @@ -465,65 +466,195 @@ def write_derivative_json(snakemake: Snakemake, **kwargs: dict[str, Any]) -> Non json.dump(sidecar, outfile, indent=4) -def _generate_filters( - include: Iterable[str] | str | None = None, - exclude: Iterable[str] | str | None = None, -) -> tuple[list[str], bool]: - """Generate Pybids filter based on inclusion or exclusion criteria. +class _Postfilter: + """Filters to apply after indexing, typically derived from the CLI. - Converts either a list of values to include or exclude in a list of Pybids - compatible filters. Unlike inclusion values, exclusion requires regex filtering. The - necessity for regex will be indicated by the boolean value of the second returned - item: True if regex is needed, False otherwise. 
Raises an exception if both include - and exclude are stipulated + Currently used for supporting ``--[exclude-]participant-label`` + """ - Parameters - ---------- - include : list of str or str, optional - Values to include, values not found in this list will be excluded, by default - None - exclude : list of str or str, optional - Values to exclude, only values not found in this list will be included, by - default None + def __init__(self): + self.inclusions: dict[str, Sequence[str] | str] = {} + self.exclusions: dict[str, Sequence[str] | str] = {} - Returns - ------- - list of str, bool - Two values: the first, a list of pybids compatible filters; the second, a - boolean indicating whether regex_search must be enabled in pybids + def add_filter( + self, + key: str, + inclusions: Iterable[str] | str | None, + exclusions: Iterable[str] | str | None, + ): + """Add entity filter based on inclusion or exclusion criteria. + + Converts a list of values to include or exclude into Pybids compatible filters. + Exclusion filters are appropriately formatted as regex. Raises an exception if + both include and exclude are stipulated + + _Postfilter is modified in-place + + Parameters + ---------- + key + Name of entity to be filtered + inclusions + Values to include, values not found in this list will be excluded, by + default None + exclusions + Values to exclude, only values not found in this list will be included, by + default None + + Raises + ------ + ValueError + Raised if both include and exclude values are stipulated. + """ + if inclusions is not None and exclusions is not None: + msg = ( + "Cannot define both participant_label and exclude_participant_label at " + "the same time" + ) + raise ValueError(msg) + if inclusions is not None: + self.inclusions[key] = list(itx.always_iterable(inclusions)) + if exclusions is not None: + self.exclusions[key] = self._format_exclusions(exclusions) - Raises - ------ - ValueError Raised of both include and exclude values are stipulated. + def _format_exclusions(self, exclusions: Iterable[str] | str): + # if multiple items to exclude, combine with with item1|item2|... + exclude_string = "|".join( + re.escape(label) for label in itx.always_iterable(exclusions) + ) + # regex to exclude subjects + return [f"^((?!({exclude_string})$).*)$"] + + +def _compile_filters(filters: FilterType) -> CompiledFilter: + return { + key: [ + Query.ANY if f is True else Query.NONE if f is False else f + for f in itx.always_iterable(filts) + ] + for key, filts in filters.items() + } + + +@attrs.define(slots=False) +class _UnifiedFilter: + """Manages component level and post filters.""" + + component: InputConfig + """The Component configuration defining the filters""" + + postfilters: _Postfilter + """Filters to be applied after collecting and parsing the data + + Currently only used to implement --[exclude-]participant-label, but in the future, + may implement other such CLI args. Unlike configuration-defined filters, these + filters apply after the dataset is indexed and queried. Thus, if a filter is set + to an empty list, a complete, albeit empty, component may be found. This is akin to + calling ``BidsComponent.filter`` after running ``generate_inputs``. 
+ + For performance purposes, non-empty post-filters are applied via ``pybids.get()`` """ - if include is not None and exclude is not None: - msg = ( - "Cannot define both participant_label and " - "exclude_participant_label at the same time" + + @classmethod + def from_filter_dict( + cls, + filters: Mapping[str, str | bool | Sequence[str | bool]], + postfilter: _Postfilter | None = None, + ) -> Self: + """Patch together a UnifiedFilter based on a basic filter dict. + + Intended primarily for use in testing + """ + wildcards: list[str] = [] + if postfilter is not None: + wildcards.extend(postfilter.inclusions) + wildcards.extend(postfilter.exclusions) + return cls( + {"filters": filters, "wildcards": wildcards}, postfilter or _Postfilter() ) - raise ValueError(msg) - # add participant_label or exclude_participant_label to search terms (if - # defined) - # we make the item key in search_terms a list so we can have both - # include and exclude defined - if include is not None: - return [*itx.always_iterable(include)], False + def _has_empty_list(self, items: Iterable[Any]): + """Check if any of the lists within iterable are empty.""" + return any(itx.ilen(itx.always_iterable(item)) == 0 for item in items) + + def _has_overlap(self, key: str): + """Check if filter key is a wildcard and not already a prefilter.""" + return key not in self.prefilters and key in self.component.get("wildcards", []) + + @ft.cached_property + def prefilters(self) -> FilterType: + """Filters defined in the component configuration and applied via pybids. + + Unlike postfilters, a prefilter set to an empty list will result in no valid + paths found, resulting in a blank (missing) component. + """ + filters = dict(self.component.get("filters", {})) + # Silently remove "regex_search". This value has been blocked by a bug for the + # since version 0.6, and even before, never fully worked properly (e.g. would + # break if combined with --exclude-participant-label) + if "regex_search" in filters: + del filters["regex_search"] + return filters + + @ft.cached_property + def filters(self) -> CompiledFilter: + """The combination pre- and post- filters to be applied to pybids indexing. + + Includes all pre-filters, and all inclusion post-filters. Empty post-filters + are replaced with Query.ANY. This allows valid paths to be found and processed + later. Post-filters are not applied when an equivalent prefilter is present + """ + result = dict(_compile_filters(self.prefilters)) + postfilters = self.postfilters.inclusions + for key in self.postfilters.inclusions: + if self._has_overlap(key): + # if empty list filter, ensure the entity filtered is present + result[key] = ( + postfilters[key] + if itx.ilen(itx.always_iterable(postfilters[key])) + else Query.ANY + ) + return result - if exclude is not None: - # if multiple items to exclude, combine with with item1|item2|... 
- exclude_string = "|".join( - re.escape(label) for label in itx.always_iterable(exclude) + @property + def post_exclusions(self) -> dict[str, Sequence[str] | str]: + """Dictionary of all post-exclusion filters.""" + return { + key: val + for key, val in self.postfilters.exclusions.items() + if self._has_overlap(key) + } + + @property + def without_bools(self) -> Mapping[str, str | Sequence[str]]: + """Check and typeguard to ensure filters do not contain booleans.""" + for key, val in self.filters.items(): + if any(isinstance(v, Query) for v in itx.always_iterable(val)): + msg = ( + "Boolean filters in items with custom paths are not supported; in " + f"component='{key}'" + ) + raise ValueError(msg) + return cast("Mapping[str, str | Sequence[str]]", self.filters) + + @property + def has_empty_prefilter(self) -> bool: + """Returns True if even one prefilter is empty.""" + return self._has_empty_list(self.prefilters.values()) + + @property + def has_empty_postfilter(self) -> bool: + """Returns True if even one postfilter is empty.""" + return self._has_empty_list( + filt + for name, filt in self.postfilters.inclusions.items() + if self._has_overlap(name) ) - # regex to exclude subjects - return [f"^((?!({exclude_string})$).*)$"], True - return [], False def _parse_custom_path( input_path: Path | str, - regex_search: bool = False, - **filters: Sequence[str | bool] | str | bool, + filters: _UnifiedFilter, ) -> ZipList: """Glob wildcards from a custom path and apply filters. @@ -558,16 +689,15 @@ def _parse_custom_path( return wildcards # Return the output values, running filtering on the zip_lists - if any( - isinstance(v, bool) for f in filters.values() for v in itx.always_iterable(f) - ): - msg = "boolean filters are not currently supported in custom path filtering" - raise TypeError(msg) - return filter_list( + result = filter_list( wildcards, - cast("Mapping[str, str | Sequence[str]]", filters), - regex_search=regex_search, + filters.without_bools, + regex_search=False, ) + if not filters.post_exclusions: + return result + + return filter_list(result, filters.post_exclusions, regex_search=True) def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, str]]: @@ -630,46 +760,24 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, return "".join(new_path), wildcard_values -@attrs.define -class _GetListsFromBidsSteps: - input_name: str - - FilterType: TypeAlias = "Mapping[str, str | bool | Sequence[str | bool]]" - - def _get_invalid_filters(self, filters: FilterType): - for key, filts in filters.items(): - try: - if True in filts or False in filts: # type: ignore - yield key, filts - except TypeError: - pass - - def prepare_bids_filters( - self, filters: Mapping[str, str | bool | Sequence[str | bool]] - ) -> dict[str, str | Query | list[str | Query]]: - return { - key: [ - Query.ANY if f is True else Query.NONE if f is False else f - for f in itx.always_iterable(filts) - ] - for key, filts in filters.items() - } - - def get_matching_files( - self, - bids_layout: BIDSLayout, - regex_search: bool, - filters: Mapping[str, str | Query | Sequence[str | Query]], - ) -> Iterable[BIDSFile]: - try: - return bids_layout.get(regex_search=regex_search, **filters) - except AttributeError as err: - msg = ( - "Pybids has encountered a problem that Snakebids cannot handle. This " - "may indicate a missing or invalid dataset_description.json for this " - "dataset." 
- ) - raise PybidsError(msg) from err +def _get_matching_files( + bids_layout: BIDSLayout, + filters: _UnifiedFilter, +) -> Iterable[BIDSFile]: + if filters.has_empty_prefilter: + return [] + try: + return bids_layout.get( + regex_search=False, + **filters.filters, + ) + except AttributeError as err: + msg = ( + "Pybids has encountered a problem that Snakebids cannot handle. This " + "may indicate a missing or invalid dataset_description.json for this " + "dataset." + ) + raise PybidsError(msg) from err def _get_lists_from_bids( @@ -677,8 +785,7 @@ def _get_lists_from_bids( pybids_inputs: InputsConfig, *, limit_to: Iterable[str] | None = None, - regex_search: bool = False, - **filters: str | Sequence[str], + postfilters: _Postfilter, ) -> Generator[BidsComponent, None, None]: """Grabs files using pybids and creates snakemake-friendly lists. @@ -704,10 +811,11 @@ def _get_lists_from_bids( One BidsComponent is yielded for each modality described by ``pybids_inputs``. """ for input_name in limit_to or list(pybids_inputs): - steps = _GetListsFromBidsSteps(input_name) _logger.debug("Grabbing inputs for %s...", input_name) component = pybids_inputs[input_name] + filters = _UnifiedFilter(component, postfilters or {}) + if "custom_path" in component: # a custom path was specified for this input, skip pybids: # get input_wildcards by parsing path for {} entries (using a set @@ -716,12 +824,7 @@ def _get_lists_from_bids( # to deal with multiple wildcards path = component["custom_path"] - zip_lists = _parse_custom_path( - path, - regex_search=regex_search, - **pybids_inputs[input_name].get("filters", {}), - **filters, - ) + zip_lists = _parse_custom_path(path, filters=filters) yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) continue @@ -734,10 +837,7 @@ def _get_lists_from_bids( zip_lists: dict[str, list[str]] = defaultdict(list) paths: set[str] = set() - pybids_filters = steps.prepare_bids_filters(component.get("filters", {})) - matching_files = steps.get_matching_files( - bids_layout, regex_search, {**pybids_filters, **filters} - ) + matching_files = _get_matching_files(bids_layout, filters) for img in matching_files: wildcards: list[str] = [ @@ -805,7 +905,15 @@ def _get_lists_from_bids( ) raise ConfigError(msg) from err - yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists) + if filters.has_empty_postfilter: + yield BidsComponent( + name=input_name, path=path, zip_lists={key: [] for key in zip_lists} + ) + continue + + yield BidsComponent(name=input_name, path=path, zip_lists=zip_lists).filter( + regex_search=True, **filters.post_exclusions + ) def get_wildcard_constraints(image_types: InputsConfig) -> dict[str, str]: diff --git a/snakebids/tests/helpers.py b/snakebids/tests/helpers.py index 294142fb..c7996464 100644 --- a/snakebids/tests/helpers.py +++ b/snakebids/tests/helpers.py @@ -209,7 +209,11 @@ def create_snakebids_config(dataset: BidsDataset) -> InputsConfig: def reindex_dataset( - root: str, dataset: BidsDataset, use_custom_paths: bool = False + root: str, + dataset: BidsDataset, + use_custom_paths: bool = False, + participant_label: str | Sequence[str] | None = None, + exclude_participant_label: str | Sequence[str] | None = None, ) -> BidsDataset: """Create BidsDataset on the filesystem and reindex @@ -217,10 +221,20 @@ def reindex_dataset( """ create_dataset(Path("/"), dataset) config = create_snakebids_config(dataset) + if participant_label is not None or exclude_participant_label is not None: + for comp in config.values(): + if "subject" in 
comp.get("filters", {}): + del comp["filters"]["subject"] # type: ignore + if use_custom_paths: for comp in config: config[comp]["custom_path"] = dataset[comp].path - return generate_inputs(root, config) + return generate_inputs( + root, + config, + participant_label=participant_label, + exclude_participant_label=exclude_participant_label, + ) def allow_function_scoped(func: _T, /) -> _T: diff --git a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index 2008dd65..2fd31375 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -13,7 +13,7 @@ import warnings from collections import defaultdict from pathlib import Path -from typing import Any, Callable, Iterable, NamedTuple, TypeVar, cast +from typing import Iterable, Literal, NamedTuple, TypedDict, TypeVar, cast import attrs import more_itertools as itx @@ -23,25 +23,29 @@ from hypothesis import strategies as st from pyfakefs.fake_filesystem import FakeFilesystem from pytest_mock import MockerFixture +from snakemake.io import expand as sb_expand from snakebids.core.datasets import BidsComponent, BidsDataset from snakebids.core.input_generation import ( _all_custom_paths, _gen_bids_layout, - _generate_filters, _get_lists_from_bids, _normalize_database_args, _parse_bids_path, _parse_custom_path, + _Postfilter, + _UnifiedFilter, generate_inputs, ) -from snakebids.exceptions import ConfigError, PybidsError, RunError +from snakebids.exceptions import ConfigError, RunError from snakebids.paths.presets import bids from snakebids.tests import strategies as sb_st from snakebids.tests.helpers import ( + Benchmark, BidsListCompare, allow_function_scoped, create_dataset, + create_snakebids_config, example_if, get_bids_path, get_zip_list, @@ -535,7 +539,7 @@ def test_missing_wildcards(self, tmpdir: Path): assert config.subj_wildcards == {"subject": "{subject}"} -class TestGenerateFilter: +class TestPostfilter: valid_chars = st.characters(blacklist_characters=["\n"]) st_lists_or_text = st.lists(st.text(valid_chars)) | st.text(valid_chars) @@ -543,52 +547,56 @@ class TestGenerateFilter: def test_throws_error_if_labels_and_excludes_are_given( self, args: tuple[list[str] | str, list[str] | str] ): + filters = _Postfilter() with pytest.raises( ValueError, - match="Cannot define both participant_label and " - "exclude_participant_label at the same time", + match="Cannot define both participant_label and exclude_participant_label ", ): - _generate_filters(*args) + filters.add_filter("foo", *args) - @given(st_lists_or_text) - def test_returns_participant_label_as_list(self, label: list[str] | str): - result = _generate_filters(label)[0] + @given(st.text(), st_lists_or_text) + def test_returns_participant_label_as_dict(self, key: str, label: list[str] | str): + filters = _Postfilter() + filters.add_filter(key, label, None) if isinstance(label, str): - assert result == [label] + assert filters.inclusions == {key: [label]} else: - assert result == label + assert filters.inclusions == {key: label} + assert filters.exclusions == {} @given( + st.text(), st_lists_or_text, st.lists(st.text(valid_chars, min_size=1), min_size=1), st.text(valid_chars, min_size=1, max_size=3), ) def test_exclude_gives_regex_that_matches_anything_except_exclude( - self, excluded: list[str] | str, dummy_values: list[str], padding: str + self, key: str, excluded: list[str] | str, dummy_values: list[str], padding: str ): + filters = _Postfilter() # Make sure the dummy_values and padding we'll be testing against 
are different # from our test values for value in dummy_values: assume(value not in itx.always_iterable(excluded)) assume(padding not in itx.always_iterable(excluded)) - result = _generate_filters(exclude=excluded) - assert result[1] is True - assert isinstance(result[0], list) - assert len(result[0]) == 1 + filters.add_filter(key, None, excluded) + assert isinstance(filters.exclusions[key], list) + assert len(filters.exclusions[key]) == 1 # We match any value that isn't the exclude string for value in dummy_values: - assert re.match(result[0][0], value) + assert re.match(filters.exclusions[key][0], value) for exclude in itx.always_iterable(excluded): # We don't match the exclude string - assert re.match(result[0][0], exclude) is None + assert re.match(filters.exclusions[key][0], exclude) is None # Addition of random strings before and/or after lets the match occur again - assert re.match(result[0][0], padding + exclude) - assert re.match(result[0][0], exclude + padding) - assert re.match(result[0][0], padding + exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude) + assert re.match(filters.exclusions[key][0], exclude + padding) + assert re.match(filters.exclusions[key][0], padding + exclude + padding) + assert filters.inclusions == {} class PathEntities(NamedTuple): @@ -685,13 +693,11 @@ def generate_test_directory( T = TypeVar("T") - def test_benchmark_test_custom_paths( - self, benchmark: Callable[[Callable[..., Any], Any], Any], tmp_path: Path - ): + def test_benchmark_test_custom_paths(self, benchmark: Benchmark, tmp_path: Path): entities = {"A": ["A", "B", "C"], "B": ["1", "2", "3"]} template = Path("{A}/A-{A}_B-{B}") test_path = self.generate_test_directory(entities, template, tmp_path) - benchmark(_parse_custom_path, test_path) + benchmark(_parse_custom_path, test_path, _UnifiedFilter.from_filter_dict({})) @allow_function_scoped @given(path_entities=path_entities()) @@ -704,7 +710,7 @@ def test_collects_all_paths_when_no_filters( test_path = self.generate_test_directory(entities, template, temp_dir) # Test without any filters - result = _parse_custom_path(test_path) + result = _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict({})) zip_lists = get_zip_list(entities, it.product(*entities.values())) assert BidsComponent( name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists @@ -724,7 +730,7 @@ def test_collects_only_filtered_entities( # Test with filters result_filtered = MultiSelectDict( - _parse_custom_path(test_path, regex_search=False, **filters) + _parse_custom_path(test_path, _UnifiedFilter.from_filter_dict(filters)) ) zip_lists = MultiSelectDict( { @@ -750,14 +756,13 @@ def test_collect_all_but_filters_when_exclusion_filters_used( entities, template, filters = path_entities test_path = self.generate_test_directory(entities, template, temp_dir) # Test with exclusion filters - exclude_filters = { - # We use _generate_filter to get our exclusion regex. 
This function was - # tested previously - key: _generate_filters(exclude=values)[0] - for key, values in filters.items() - } + exclude_filters = _Postfilter() + for key, values in filters.items(): + exclude_filters.add_filter(key, None, values) result_excluded = MultiSelectDict( - _parse_custom_path(test_path, regex_search=True, **exclude_filters) + _parse_custom_path( + test_path, _UnifiedFilter.from_filter_dict({}, exclude_filters) + ) ) entities_excluded = { @@ -1083,26 +1088,6 @@ def test_t1w_with_dict(): assert config["subj_wildcards"] == {"subject": "{subject}"} -@pytest.mark.skipif( - sys.version_info >= (3, 8), - reason=""" - Bug only surfaces on python 3.7 because higher python versions have access to the - latest pybids version - """, -) -def test_get_lists_from_bids_raises_pybids_error(): - """Test that we wrap a cryptic AttributeError from pybids with PybidsError. - - Pybids raises an AttributeError when a BIDSLayout.get is called with scope not - equal to 'all' on a layout that indexes a dataset without a - dataset_description.json. We wrap this error with something a bit less cryptic, so - this test ensures that that behaviour is still present. - """ - layout = BIDSLayout("snakebids/tests/data/bids_t1w", validate=False) - with pytest.raises(PybidsError): - next(_get_lists_from_bids(layout, {"t1": {"filters": {"scope": "raw"}}})) - - def test_get_lists_from_bids_raises_run_error(): bids_layout = None pybids_inputs: InputsConfig = { @@ -1112,7 +1097,9 @@ def test_get_lists_from_bids_raises_run_error(): } } with pytest.raises(RunError): - next(_get_lists_from_bids(bids_layout, pybids_inputs)) + next( + _get_lists_from_bids(bids_layout, pybids_inputs, postfilters=_Postfilter()) + ) def test_get_lists_from_bids(): @@ -1151,7 +1138,7 @@ def test_get_lists_from_bids(): pybids_inputs["t1"]["custom_path"] = wildcard_path_t1 pybids_inputs["t2"]["custom_path"] = wildcard_path_t2 - result = _get_lists_from_bids(layout, pybids_inputs) + result = _get_lists_from_bids(layout, pybids_inputs, postfilters=_Postfilter()) for bids_lists in result: if bids_lists.input_name == "t1": template = BidsComponent( @@ -1267,6 +1254,217 @@ def test_generate_inputs(dataset: BidsDataset, bids_fs: Path, fakefs_tmpdir: Pat assert reindexed.layout is not None +@st.composite +def dataset_with_subject(draw: st.DrawFn): + entities = draw(sb_st.bids_entity_lists(blacklist_entities=["subject"])) + entities += ["subject"] + return BidsDataset.from_iterable( + [ + draw( + sb_st.bids_components( + whitelist_entities=entities, + min_entities=len(entities), + max_entities=len(entities), + restrict_patterns=True, + unique=True, + ) + ) + ] + ) + + +class TestParticipantFiltering: + MODE = Literal["include", "exclude"] + + @pytest.fixture + def tmpdir(self, bids_fs: Path, fakefs_tmpdir: Path): + return fakefs_tmpdir + + def get_filter_params(self, mode: MODE, filters: list[str] | str): + class FiltParams(TypedDict, total=False): + participant_label: list[str] | str + exclude_participant_label: list[str] | str + + if mode == "include": + return FiltParams({"participant_label": filters}) + if mode == "exclude": + return FiltParams({"exclude_participant_label": filters}) + msg = f"Invalid mode specification: {mode}" + raise ValueError(msg) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_filters_comps_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + 
root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, participant_label=label) + assert set(itx.first(reindexed.values()).entities["subject"]) == set( + itx.always_iterable(label) + ) + + @given( + data=st.data(), + dataset=dataset_with_subject(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_label_filters_comp_with_subject( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=sb_st.datasets_one_comp(blacklist_entities=["subject"], unique=True), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_without_subject( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + reindexed = reindex_dataset( + root, rooted, **self.get_filter_params(mode, participant_filter) + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + create_dataset(Path("/"), rooted) + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @pytest.mark.parametrize("mode", ["include", "exclude"]) + @given( + dataset=dataset_with_subject(), + participant_filter=st.lists(st.text(min_size=1)) | st.text(min_size=1), + data=st.data(), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_participant_label_doesnt_filter_comps_when_subject_has_filter_no_wcard( + self, + mode: MODE, + dataset: BidsDataset, + participant_filter: list[str] | str, + data: st.DataObject, + tmpdir: Path, + ): + root = tempfile.mkdtemp(dir=tmpdir) + 
rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + subject = data.draw( + st.sampled_from(itx.first(rooted.values()).entities["subject"]) + ) + create_dataset(Path("/"), rooted) + config = create_snakebids_config(rooted) + for comp in config.values(): + comp["filters"] = dict(comp.get("filters", {})) + comp["filters"]["subject"] = subject + reindexed = generate_inputs( + root, + create_snakebids_config(rooted), + **self.get_filter_params(mode, participant_filter), + ) + assert reindexed == rooted + + @given( + data=st.data(), + dataset=dataset_with_subject().filter( + lambda ds: set(itx.first(ds.values()).wildcards) != {"subject", "extension"} + ), + ) + @settings( + deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] + ) + def test_exclude_participant_does_not_make_all_other_filters_regex( + self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path + ): + root = tempfile.mkdtemp(dir=tmpdir) + rooted = BidsDataset.from_iterable( + attrs.evolve(comp, path=os.path.join(root, comp.path)) + for comp in dataset.values() + ) + + # Create an extra set of paths by modifing one of the existing components to put + # foo after a set of entity values. If that filter gets changed to a regex, all + # of the suffixed decoys will get picked up by pybids + ziplist = dict(itx.first(rooted.values()).zip_lists) + mut_entity = itx.first( + filter(lambda e: e not in {"subject", "extension"}, ziplist) + ) + ziplist[mut_entity] = ["foo" + v for v in ziplist[mut_entity]] + for path in sb_expand(itx.first(rooted.values()).path, zip, **ziplist): + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + p.touch() + + sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) + label = data.draw(st.lists(sampler, unique=True) | sampler) + reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) + reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) + original_subjects = set(itx.first(rooted.values()).entities["subject"]) + assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + + # The content of the dataset is irrelevant to this test, so one example suffices # but can't use extension because the custom path won't glob properly @settings(max_examples=1, suppress_health_check=[HealthCheck.function_scoped_fixture]) diff --git a/snakebids/types.py b/snakebids/types.py index 72a0a7be..342f4fe1 100644 --- a/snakebids/types.py +++ b/snakebids/types.py @@ -23,7 +23,7 @@ class InputConfig(TypedDict, total=False): """Configuration for a single bids component.""" - filters: dict[str, str | bool | list[str | bool]] + filters: Mapping[str, str | bool | Sequence[str | bool]] """Filters to pass on to :class:`BIDSLayout.get() ` Each key refers to the name of an entity. Values may take the following forms:
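As a closing illustration of how the boolean filter values documented in `types.py` and `config.md` interact with pybids, here is a minimal standalone sketch. It mirrors the patch's new `_compile_filters` helper rather than importing it, and the example entity values are made up; treat it as an assumption-laden sketch, not the library's public API.

```python
from bids.layout import Query
import more_itertools as itx


def compile_filters(filters):
    # True  -> Query.ANY  (entity must be present, any value)
    # False -> Query.NONE (entity must be absent)
    # Strings pass through unchanged; scalar values become one-element lists
    return {
        key: [
            Query.ANY if f is True else Query.NONE if f is False else f
            for f in itx.always_iterable(filts)
        ]
        for key, filts in filters.items()
    }


compiled = compile_filters({"datatype": "func", "desc": True, "acquisition": False})
# compiled == {"datatype": ["func"], "desc": [Query.ANY], "acquisition": [Query.NONE]}
```

These compiled filters are what the patch ultimately forwards to `BIDSLayout.get()` in `_get_matching_files`, which is how a bare `desc: true` filter requires the `_desc-` entity to be present without pinning it to a particular value.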