Skip to content

Commit

Permalink
Update argument handling for --derivatives
Browse files Browse the repository at this point in the history
The initial implementation of bidsargs had --derivatives take the same
arguments as previously specified in the template app: action="+". This
made the argument alway look for one or more paths. This does not
reflect the pybids api, which additionally takes True and False

The fix adjusts this so that --derivatives defaults to False, if
provided by itself give True, and if provide with paths, gives a list of
Path().
  • Loading branch information
pvandyken committed May 9, 2024
1 parent 156f58f commit 1ead79d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 6 deletions.
4 changes: 2 additions & 2 deletions snakebids/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class AddArgumentArgs(TypedDict, Generic[_T], total=False):
nargs: int | _NArgsStr | _SUPPRESS_T | None
const: Any
default: Any
type: Callable[[str], _T] | argparse.FileType
type: Callable[[Any], _T] | argparse.FileType
choices: Iterable[_T] | None
required: bool
help: str | None
Expand All @@ -49,7 +49,7 @@ class AddArgumentArgs(TypedDict, total=False):
nargs: int | _NArgsStr | _SUPPRESS_T | None
const: Any
default: Any
type: Callable[[str], Any] | argparse.FileType
type: Callable[[Any], Any] | argparse.FileType
choices: Iterable[Any] | None
required: bool
help: str | None
Expand Down
25 changes: 23 additions & 2 deletions snakebids/plugins/bidsargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import argparse
from pathlib import Path
from typing import Any, Iterable
from typing import Any, Iterable, Sequence

import attrs
import more_itertools as itx
Expand All @@ -17,6 +17,25 @@ def _list_or_none(arg: str | Iterable[str] | None) -> list[str] | None:
return arg if arg is None else list(itx.always_iterable(arg))


class _Derivative(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str | Sequence[Any] | None,
option_string: str | None = None,
):
if values is None: # pragma: no cover
result = False
elif isinstance(values, str):
result = [Path(values)]

Check warning on line 31 in snakebids/plugins/bidsargs.py

View check run for this annotation

Codecov / codecov/patch

snakebids/plugins/bidsargs.py#L31

Added line #L31 was not covered by tests
elif not len(values):
result = True
else:
result = [Path(p) for p in values]
setattr(namespace, self.dest, result)


@attrs.define
class BidsArgs(PluginBase):
"""Add basic BIDSApp arguments.
Expand Down Expand Up @@ -182,7 +201,9 @@ def add_cli_arguments(
"--derivatives",
help="Path(s) to a derivatives dataset, for folder(s) that contains "
"multiple derivatives datasets",
nargs="+",
nargs="*",
dest="derivatives",
metavar="PATH",
action=_Derivative,
default=False,
)
7 changes: 5 additions & 2 deletions snakebids/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
alphanum = ascii_letters + digits
valid_entities: tuple[str, ...] = tuple(BidsConfig.load("bids").entities.keys())

path_characters = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8")


def nothing() -> Any:
return st.nothing() # type: ignore
Expand All @@ -48,9 +50,10 @@ def paths(
absolute: bool | None = None,
resolve: bool = False,
) -> st.SearchStrategy[Path]:
valid_chars = st.characters(blacklist_characters=["/", "\x00"], codec="UTF-8")
paths = st.lists(
st.text(valid_chars, min_size=1), min_size=min_segments, max_size=max_segments
st.text(path_characters, min_size=1),
min_size=min_segments,
max_size=max_segments,
).map(lambda x: Path(*x))

relative_paths = paths.filter(lambda p: not p.is_absolute())
Expand Down
31 changes: 31 additions & 0 deletions snakebids/tests/test_plugins/test_bidsargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools as it
from argparse import ArgumentParser
from pathlib import Path
from typing import Iterable

import pytest
Expand All @@ -10,6 +11,7 @@

from snakebids.exceptions import ConfigError
from snakebids.plugins.bidsargs import BidsArgs
from snakebids.tests import strategies as sb_st


class _St:
Expand Down Expand Up @@ -65,3 +67,32 @@ def test_analysis_levels_can_be_defined_in_config(
for arg in choices:
nspc = parser.parse_args(["...", "...", arg])
assert nspc.analysis_level == arg

@given(
derivatives=st.lists(
st.text(sb_st.path_characters).filter(lambda s: not s.startswith("-")),
min_size=1,
)
)
def test_derivatives_converted_to_paths(self, derivatives: list[str]):
parser = ArgumentParser()
bidsargs = BidsArgs()
bidsargs.add_cli_arguments(parser, {}, {})
nspc = parser.parse_args(
["...", "...", "participant", "--derivatives", *derivatives]
)
assert nspc.derivatives == [Path(p) for p in derivatives]

def test_derivatives_true_if_no_paths_given(self):
parser = ArgumentParser()
bidsargs = BidsArgs()
bidsargs.add_cli_arguments(parser, {}, {})
nspc = parser.parse_args(["...", "...", "participant", "--derivatives"])
assert nspc.derivatives is True

def test_derivatives_false_by_default(self):
parser = ArgumentParser()
bidsargs = BidsArgs()
bidsargs.add_cli_arguments(parser, {}, {})
nspc = parser.parse_args(["...", "...", "participant"])
assert nspc.derivatives is False

0 comments on commit 1ead79d

Please sign in to comment.