diff --git a/snakebids/plugins/base.py b/snakebids/plugins/base.py index fabd20ff..44495559 100644 --- a/snakebids/plugins/base.py +++ b/snakebids/plugins/base.py @@ -30,7 +30,7 @@ class AddArgumentArgs(TypedDict, Generic[_T], total=False): nargs: int | _NArgsStr | _SUPPRESS_T | None const: Any default: Any - type: Callable[[str], _T] | argparse.FileType + type: Callable[[Any], _T] | argparse.FileType choices: Iterable[_T] | None required: bool help: str | None @@ -49,7 +49,7 @@ class AddArgumentArgs(TypedDict, total=False): nargs: int | _NArgsStr | _SUPPRESS_T | None const: Any default: Any - type: Callable[[str], Any] | argparse.FileType + type: Callable[[Any], Any] | argparse.FileType choices: Iterable[Any] | None required: bool help: str | None diff --git a/snakebids/plugins/bidsargs.py b/snakebids/plugins/bidsargs.py index f54fa00c..275b99d8 100644 --- a/snakebids/plugins/bidsargs.py +++ b/snakebids/plugins/bidsargs.py @@ -2,7 +2,7 @@ import argparse from pathlib import Path -from typing import Any, Iterable +from typing import Any, Iterable, Sequence import attrs import more_itertools as itx @@ -17,6 +17,25 @@ def _list_or_none(arg: str | Iterable[str] | None) -> list[str] | None: return arg if arg is None else list(itx.always_iterable(arg)) +class _Derivative(argparse.Action): + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ): + if values is None: # pragma: no cover + result = False + elif isinstance(values, str): + result = [Path(values)] + elif not len(values): + result = True + else: + result = [Path(p) for p in values] + setattr(namespace, self.dest, result) + + @attrs.define class BidsArgs(PluginBase): """Add basic BIDSApp arguments. @@ -182,7 +201,9 @@ def add_cli_arguments( "--derivatives", help="Path(s) to a derivatives dataset, for folder(s) that contains " "multiple derivatives datasets", - nargs="+", + nargs="*", dest="derivatives", metavar="PATH", + action=_Derivative, + default=False, ) diff --git a/snakebids/tests/strategies.py b/snakebids/tests/strategies.py index c75678b2..0588e742 100644 --- a/snakebids/tests/strategies.py +++ b/snakebids/tests/strategies.py @@ -36,6 +36,8 @@ alphanum = ascii_letters + digits valid_entities: tuple[str, ...] = tuple(BidsConfig.load("bids").entities.keys()) +path_characters = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8") + def nothing() -> Any: return st.nothing() # type: ignore @@ -48,9 +50,10 @@ def paths( absolute: bool | None = None, resolve: bool = False, ) -> st.SearchStrategy[Path]: - valid_chars = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8") paths = st.lists( - st.text(valid_chars, min_size=1), min_size=min_segments, max_size=max_segments + st.text(path_characters, min_size=1), + min_size=min_segments, + max_size=max_segments, ).map(lambda x: Path(*x)) relative_paths = paths.filter(lambda p: not p.is_absolute()) diff --git a/snakebids/tests/test_plugins/test_bidsargs.py b/snakebids/tests/test_plugins/test_bidsargs.py index 3c79f5c4..38e2aa85 100644 --- a/snakebids/tests/test_plugins/test_bidsargs.py +++ b/snakebids/tests/test_plugins/test_bidsargs.py @@ -2,6 +2,7 @@ import itertools as it from argparse import ArgumentParser +from pathlib import Path from typing import Iterable import pytest @@ -10,6 +11,7 @@ from snakebids.exceptions import ConfigError from snakebids.plugins.bidsargs import BidsArgs +from snakebids.tests import strategies as sb_st class _St: @@ -65,3 +67,32 @@ def test_analysis_levels_can_be_defined_in_config( for arg in choices: nspc = parser.parse_args(["...", "...", arg]) assert nspc.analysis_level == arg + + @given( + derivatives=st.lists( + st.text(sb_st.path_characters).filter(lambda s: not s.startswith("-")), + min_size=1, + ) + ) + def test_derivatives_converted_to_paths(self, derivatives: list[str]): + parser = ArgumentParser() + bidsargs = BidsArgs() + bidsargs.add_cli_arguments(parser, {}, {}) + nspc = parser.parse_args( + ["...", "...", "participant", "--derivatives", *derivatives] + ) + assert nspc.derivatives == [Path(p) for p in derivatives] + + def test_derivatives_true_if_no_paths_given(self): + parser = ArgumentParser() + bidsargs = BidsArgs() + bidsargs.add_cli_arguments(parser, {}, {}) + nspc = parser.parse_args(["...", "...", "participant", "--derivatives"]) + assert nspc.derivatives is True + + def test_derivatives_false_by_default(self): + parser = ArgumentParser() + bidsargs = BidsArgs() + bidsargs.add_cli_arguments(parser, {}, {}) + nspc = parser.parse_args(["...", "...", "participant"]) + assert nspc.derivatives is False