diff --git a/snakebids/core/_querying.py b/snakebids/core/_querying.py index 4561a464..da5daf64 100644 --- a/snakebids/core/_querying.py +++ b/snakebids/core/_querying.py @@ -57,12 +57,6 @@ def add_filter( ValueError Raised if both include and exclude values are stipulated. """ - if inclusions is not None and exclusions is not None: - msg = ( - "Cannot define both participant_label and exclude_participant_label at " - "the same time" - ) - raise ValueError(msg) if inclusions is not None: self.inclusions[key] = list(itx.always_iterable(inclusions)) if exclusions is not None: diff --git a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index d09c3b3c..6d1faf7d 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -899,17 +899,6 @@ class TestPostfilter: valid_chars = st.characters(blacklist_characters=["\n"]) st_lists_or_text = st.lists(st.text(valid_chars)) | st.text(valid_chars) - @given(st.tuples(st_lists_or_text, st_lists_or_text)) - def test_throws_error_if_labels_and_excludes_are_given( - self, args: tuple[list[str] | str, list[str] | str] - ): - filters = PostFilter() - with pytest.raises( - ValueError, - match="Cannot define both participant_label and exclude_participant_label ", - ): - filters.add_filter("foo", *args) - @given(st.text(), st_lists_or_text) def test_returns_participant_label_as_dict(self, key: str, label: list[str] | str): filters = PostFilter() @@ -1278,20 +1267,6 @@ def test_t1w(): } } - # Can't define particpant_label and exclude_participant_label - with pytest.raises( - ValueError, - match="Cannot define both participant_label and " - "exclude_participant_label at the same time", - ): - result = generate_inputs( - pybids_inputs=pybids_inputs, - bids_dir=real_bids_dir, - derivatives=derivatives, - participant_label="001", - exclude_participant_label="002", - ) - # Simplest case -- one input type, using pybids result = generate_inputs( pybids_inputs=pybids_inputs, @@ -1707,7 +1682,7 @@ class FiltParams(TypedDict, total=False): @settings( deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] ) - def test_participant_label_filters_comps_with_subject( + def test_exclude_and_participant_label_filter_correctly( self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path ): root = tempfile.mkdtemp(dir=tmpdir) @@ -1716,33 +1691,19 @@ def test_participant_label_filters_comps_with_subject( for comp in dataset.values() ) sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) - label = data.draw(st.lists(sampler, unique=True) | sampler) - reindexed = reindex_dataset(root, rooted, participant_label=label) - assert set(itx.first(reindexed.values()).entities["subject"]) == set( - itx.always_iterable(label) - ) - - @given( - data=st.data(), - dataset=dataset_with_subject(), - ) - @settings( - deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture] - ) - def test_exclude_participant_label_filters_comp_with_subject( - self, data: st.DataObject, dataset: BidsDataset, tmpdir: Path - ): - root = tempfile.mkdtemp(dir=tmpdir) - rooted = BidsDataset.from_iterable( - attrs.evolve(comp, path=os.path.join(root, comp.path)) - for comp in dataset.values() + excluded = data.draw(st.lists(sampler, unique=True) | sampler | st.none()) + included = data.draw(st.lists(sampler, unique=True) | sampler | st.none()) + reindexed = reindex_dataset( + root, rooted, exclude_participant_label=excluded, participant_label=included ) - sampler = st.sampled_from(itx.first(rooted.values()).entities["subject"]) - label = data.draw(st.lists(sampler, unique=True) | sampler) - reindexed = reindex_dataset(root, rooted, exclude_participant_label=label) reindexed_subjects = set(itx.first(reindexed.values()).entities["subject"]) - original_subjects = set(itx.first(rooted.values()).entities["subject"]) - assert reindexed_subjects == original_subjects - set(itx.always_iterable(label)) + expected_subjects = set(itx.first(rooted.values()).entities["subject"]) + if included is not None: + expected_subjects &= set(itx.always_iterable(included)) + if excluded is not None: + expected_subjects -= set(itx.always_iterable(excluded)) + + assert reindexed_subjects == expected_subjects @pytest.mark.parametrize("mode", ["include", "exclude"]) @given(