Skip to content

Commit

Permalink
Refactor bids() and add test suite
Browse files Browse the repository at this point in the history
bids() is refactored to use the new bids specs, paving the way for
future customizable bids functions and speeding it up 2-3 fold. This
commit also adds a brand new test suite.
  • Loading branch information
pvandyken committed Aug 24, 2023
1 parent 98d9b28 commit 8d3519e
Show file tree
Hide file tree
Showing 9 changed files with 510 additions and 197 deletions.
219 changes: 120 additions & 99 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ scipy = [
{ version = ">=1.10.0", python = ">=3.9" }
]


[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
pytest = "^7.0.0"
Expand All @@ -82,6 +81,7 @@ pyparsing = "^3.0.9"
# calls
pyright = ">=1.1.307,<1.1.312"
ruff = "^0.0.280"
pathvalidate = "^3.0.0"

[tool.poetry.scripts]
snakebids = "snakebids.admin:main"
Expand All @@ -93,7 +93,7 @@ build-backend = "poetry_dynamic_versioning.backend"
[tool.poe.tasks]
setup = "pre-commit install"
quality = { shell = "isort snakebids && black snakebids && ruff snakebids && pyright" }
fix = { shell = "isort snakebids && black snakebids && ruff --fix snakebids"}
fix = { shell = "ruff --fix snakebids && isort snakebids && black snakebids"}
test = """
pytest --doctest-modules --ignore=docs \
--ignore=snakebids/project_template --benchmark-disable
Expand Down
153 changes: 80 additions & 73 deletions snakebids/paths/presets.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,58 @@
"""Utilities for converting Snakemake apps to BIDS apps."""
from __future__ import annotations

from collections import OrderedDict
import functools as ft
import itertools as it
import os
import sys
from pathlib import Path

import more_itertools as itx

from snakebids.paths.specs import v0_0_0


@ft.lru_cache
def _parse_spec(include_subject_dir: bool, include_session_dir: bool):
    """Precompute the lookup tables ``bids()`` needs from the v0.0.0 path spec.

    The result is cached per ``(include_subject_dir, include_session_dir)`` pair,
    so the returned ``order`` list and ``dirs`` set are shared between calls and
    must be treated as read-only by callers.

    Parameters
    ----------
    include_subject_dir
        Forwarded to ``v0_0_0`` as ``subject_dir``.
    include_session_dir
        Forwarded to ``v0_0_0`` as ``session_dir``.

    Returns
    -------
    tuple
        ``(order, dirs, parse_entities)`` where ``order`` lists the tag of every
        spec entry in spec order, ``dirs`` is the set of tags whose spec entry has
        a truthy ``dir`` field, and ``parse_entities`` is a function normalizing a
        dict of user-supplied entities into a ``{tag: str(value)}`` dict.
    """
    spec = v0_0_0(subject_dir=include_subject_dir, session_dir=include_session_dir)
    order: list[str] = []
    dirs: set[str] = set()
    # entity (long) name -> tag (short name); identical when no tag is given
    aliases: dict[str, str] = {}
    # tag -> first entity name declaring it (reverse of ``aliases``), used for
    # error reporting only
    long_names: dict[str, str] = {}

    for entry in spec:
        tag = entry.get("tag", entry["entity"])
        order.append(tag)
        aliases[entry["entity"]] = tag
        long_names.setdefault(tag, entry["entity"])
        if entry.get("dir"):
            dirs.add(tag)

    def parse_entities(entities: dict[str, str | bool]) -> dict[str, str]:
        """Map user-supplied entity names to their tags and stringify the values.

        Raises
        ------
        ValueError
            If two keys resolve to the same tag, either because both the long
            and the short name of an entity were given, or because two keys
            differ only by trailing underscores.
        """
        result: dict[str, str] = {}
        # tag -> key exactly as the caller wrote it, for precise error messages
        given_as: dict[str, str] = {}
        for entity, val in entities.items():
            # strip underscores from keys (needed so that users can use reserved
            # keywords by appending a _)
            stripped = entity.rstrip("_")
            unaliased = aliases.get(stripped, stripped)
            if unaliased in result:
                first_key = given_as[unaliased]
                if first_key.rstrip("_") != stripped:
                    # Two different names resolved to one tag, so one of them
                    # must be an alias from the spec: a long/short name clash.
                    long_name = long_names[unaliased]
                    err = (
                        "Long and short names of an entity cannot be used in the same "
                        f"call to bids(): got '{long_name}' and '{unaliased}'"
                    )
                else:
                    # Same name given twice, differing only by trailing
                    # underscores (also covers entities absent from the spec).
                    err = (
                        f"Entity '{unaliased}' was specified more than once in the "
                        f"same call to bids(): got '{first_key}' and '{entity}'"
                    )
                raise ValueError(err)
            result[unaliased] = str(val)
            given_as[unaliased] = entity
        return result

    return order, dirs, parse_entities


def bids(
root: str | Path | None = None,
datatype: str | None = None,
prefix: str | None = None,
suffix: str | None = None,
extension: str | None = None,
subject: str | None = None,
session: str | None = None,
include_subject_dir: bool = True,
include_session_dir: bool = True,
**entities: str,
**entities: str | bool,
) -> str:
"""Helper function for generating bids paths for snakemake workflows.
Expand Down Expand Up @@ -126,11 +163,9 @@ def bids(
* Some code adapted from mne-bids, specifically
https://mne.tools/mne-bids/stable/_modules/mne_bids/utils.html
"""
if not any([entities, suffix, subject, session, extension]) and any(
[datatype, prefix]
):
if not any([entities, suffix, extension]) and any([datatype, prefix]):
raise ValueError(
"At least one of subject, session, suffix, extension, or an entity must be "
"At least one of suffix, extension, or an entity must be "
"supplied.\n\tGot only: "
+ " and ".join(
filter(
Expand All @@ -143,69 +178,41 @@ def bids(
)
)

# replace underscores in keys (needed so that users can use reserved
# keywords by appending a _)
entities = {k.replace("_", ""): v for k, v in entities.items()}

# strict ordering of bids entities is specified here:
order: OrderedDict[str, str | None] = OrderedDict(
[
("task", None),
("acq", None),
("ce", None),
("rec", None),
("dir", None),
("run", None),
("mod", None),
("echo", None),
("hemi", None),
("space", None),
("res", None),
("den", None),
("label", None),
("desc", None),
]
)
include_subject_dir = bool(entities.pop("include_subject_dir", True))
include_session_dir = bool(entities.pop("include_session_dir", True))

# Now add in entities (this preserves ordering above)
for key, val in entities.items():
order[key] = val

# Form all entities for filename as a list, and join with "_". Any undefined
# entities will be `None` and will be filtered out.
filename: str = "_".join(
filter(
None,
[
# Put the prefix before anything else
prefix,
# Add in subject and session
f"sub-{subject}" if subject else None,
f"ses-{session}" if session else None,
# Iterate through all other entities and add as "key-value"
*(f"{key}-{val}" for key, val in order.items() if val is not None),
# Put the suffix last
suffix,
],
)
) + (extension or "")

# If all entities were `None`, the list will be empty and filename == ""
if not filename:
return ""

# Form folder using list similar to filename, above. Filter out Nones, and convert
# to Path.
folder = Path(
*filter(
None,
[
str(root) if root else None,
f"sub-{subject}" if subject and include_subject_dir else None,
f"ses-{session}" if session and include_session_dir else None,
datatype,
],
)
order, dirs, parse_entities = _parse_spec(
include_subject_dir=include_subject_dir, include_session_dir=include_session_dir
)
parsed = parse_entities(entities)

spec_parts: list[str] = []
custom_parts: list[str] = []
split: int = sys.maxsize + 1
path_parts: list[str] = []

if root:
path_parts.append(str(root))
if prefix:
spec_parts.append(prefix)
for entity in order:
# Check for `*` first so that if user specifies an entity called `*` we don't
# skip setting the split
if entity == "*":
split = len(path_parts)
elif value := parsed.pop(entity, None):
spec_parts.append(f"{entity}-{value}")
if entity in dirs:
path_parts.append(f"{entity}-{value}")
for key, value in parsed.items():
custom_parts.append(f"{key}-{value}")

if datatype:
path_parts.append(datatype)
path_parts.append(
"_".join(it.chain(spec_parts[:split], custom_parts, spec_parts[split:]))
)

tail = f"_{suffix}{extension or ''}" if suffix else extension or ""

return str(folder / filename)
return os.path.join(*path_parts) + tail
10 changes: 5 additions & 5 deletions snakebids/paths/resources/spec.0.0.0.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
- entity: sub
name: subject
- entity: "subject"
tag: "sub"
dir: true
- entity: ses
name: session
- entity: "session"
tag: "ses"
dir: true
- entity: "task"
- entity: "acq"
Expand All @@ -15,6 +15,6 @@
- entity: "hemi"
- entity: "space"
- entity: "res"
- entity: den
- entity: "den"
- entity: "label"
- entity: "desc"
6 changes: 3 additions & 3 deletions snakebids/paths/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class BidsPathEntitySpec(TypedDict):
entity: str
name: NotRequired[str]
tag: NotRequired[str]
dir: NotRequired[bool]


Expand Down Expand Up @@ -38,9 +38,9 @@ def v0_0_0(subject_dir: bool = True, session_dir: bool = True) -> BidsPathSpec:
"""
spec = yaml.safe_load(impr.files(resources).joinpath("spec.0.0.0.yaml").read_text())
if not subject_dir:
_find_entity(spec, "sub")["dir"] = False
_find_entity(spec, "subject")["dir"] = False

if not session_dir:
_find_entity(spec, "ses")["dir"] = False
_find_entity(spec, "session")["dir"] = False

return spec
23 changes: 22 additions & 1 deletion snakebids/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Iterable,
List,
Mapping,
Protocol,
Sequence,
TypeVar,
)
Expand All @@ -29,6 +30,7 @@
from snakebids.utils.utils import BidsEntity, MultiSelectDict

_T = TypeVar("_T")
_T_contra = TypeVar("_T_contra", contravariant=True)


def get_zip_list(
Expand Down Expand Up @@ -93,7 +95,6 @@ def get_tag(entity: BidsEntity) -> tuple[str, str]:
return entity.wildcard, f"{{{entity.wildcard}}}"

return bids(
root=".",
**dict(get_tag(BidsEntity(entity)) for entity in sorted(entities)),
**dict(
(BidsEntity(entity).wildcard, value) for entity, value in extras.items()
Expand Down Expand Up @@ -298,3 +299,23 @@ def __contains__(self, x: object, /) -> bool:
if x in entry:
return True
return False


class Benchmark(Protocol):
    """Structural type of a callable taking a function plus its arguments.

    The call is typed to return the same type as ``func`` returns.
    NOTE(review): presumably models the pytest-benchmark ``benchmark`` fixture;
    confirm against the tests that use it.
    """

    def __call__(
        self,
        func: Callable[_P, _T],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> _T:
        ...


"""Comparison Dunders copied from typeshed"""


class SupportsDunderLT(Protocol[_T_contra]):
    """Structural type for objects that implement ``<`` against ``_T_contra``."""

    def __lt__(self, __other: _T_contra) -> bool:
        ...


def is_strictly_increasing(items: Iterable[SupportsDunderLT[Any]]) -> bool:
    """Return True if every item is less than the item that follows it.

    Iterables with fewer than two items yield no pairs and so return True.
    """
    # itx.pairwise properly aliases it.pairwise on py310+
    for earlier, later in itx.pairwise(items):
        if not earlier < later:
            return False
    return True
4 changes: 4 additions & 0 deletions snakebids/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
valid_entities: tuple[str] = tuple(BidsConfig.load("bids").entities.keys())


def nothing() -> Any:
    """Return ``st.nothing()`` typed as ``Any``."""
    strategy: Any = st.nothing()  # type: ignore
    return strategy


def bids_entity(
*,
blacklist_entities: Container[BidsEntity | str] | None = None,
Expand Down
Loading

0 comments on commit 8d3519e

Please sign in to comment.