From 1e00e3b2b98adaf6301628f8b0c3afbf9ffd96c8 Mon Sep 17 00:00:00 2001 From: Peter Van Dyken Date: Fri, 26 Jul 2024 10:22:51 -0400 Subject: [PATCH 1/2] Transition from zip-list to entry core The entry table uses entry-first strides. This will be a much more efficient storage method for future API that makes individual entries more accessible. As much of the core as is currently feasible is rewritten to use entries over zip-lists The public API is completely unchanged as of now. Tests for filtering are largely unchanged, although some may be redundant at this point. --- snakebids/core/_table.py | 122 +++++++++++++++ snakebids/core/datasets.py | 94 ++++++------ snakebids/core/input_generation.py | 131 ++++++++-------- snakebids/exceptions.py | 11 ++ snakebids/tests/helpers.py | 24 ++- snakebids/tests/strategies.py | 66 ++++++-- snakebids/tests/test_datasets.py | 132 ++++++---------- snakebids/tests/test_generate_inputs.py | 191 +++++++++++++----------- snakebids/tests/test_printing.py | 18 +-- snakebids/tests/test_tables.py | 79 ++++++++++ snakebids/utils/snakemake_io.py | 62 ++++++++ snakebids/utils/utils.py | 9 -- 12 files changed, 605 insertions(+), 334 deletions(-) create mode 100644 snakebids/core/_table.py create mode 100644 snakebids/tests/test_tables.py diff --git a/snakebids/core/_table.py b/snakebids/core/_table.py new file mode 100644 index 00000000..95677227 --- /dev/null +++ b/snakebids/core/_table.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Mapping + +import attrs +import more_itertools as itx + +from snakebids.io.printing import format_zip_lists +from snakebids.types import ZipList, ZipListLike +from snakebids.utils.containers import ContainerBag, MultiSelectDict, RegexContainer + +if TYPE_CHECKING: + + def wcard_tuple(x: Iterable[str]) -> tuple[str, ...]: + return tuple(x) + + def entries_list(x: Iterable[tuple[str, ...]]) -> list[tuple[str, ...]]: + return list(x) + + def liststr() -> list[str]: + return [] +else: + wcard_tuple = tuple + entries_list = list + liststr = list + + +@attrs.frozen(kw_only=True) +class BidsTable: + """Container holding the entries of a BidsComponent.""" + + wildcards: tuple[str, ...] = attrs.field(converter=wcard_tuple) + entries: list[tuple[str, ...]] = attrs.field(converter=entries_list) + + def __bool__(self): + """Return True if one or more entries, otherwise False.""" + return bool(self.entries) + + def __eq__(self, other: BidsTable | object): + if not isinstance(other, self.__class__): + return False + if set(self.wildcards) != set(other.wildcards): + return False + if len(self.entries) != len(other.entries): + return False + if self.wildcards == other.wildcards: + return sorted(self.entries) == sorted(other.entries) + ixs = [other.wildcards.index(w) for w in self.wildcards] + entries = self.entries.copy() + try: + for entry in other.entries: + sorted_entry = tuple(entry[i] for i in ixs) + entries.remove(sorted_entry) + except ValueError: + return False + return True + + @classmethod + def from_dict(cls, d: ZipListLike): + """Construct BidsTable from a mapping of entities to value lists.""" + lengths = {len(val) for val in d.values()} + if len(lengths) > 1: + msg = "each entity must have the same number of values" + raise ValueError(msg) + return cls(wildcards=d.keys(), entries=zip(*d.values())) + + def to_dict(self) -> ZipList: + """Convert into a zip_list.""" + if not self.entries: + return MultiSelectDict(zip(self.wildcards, itx.repeatfunc(liststr))) + return MultiSelectDict(zip(self.wildcards, map(list, zip(*self.entries)))) + + def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str: + """Pretty-format.""" + return format_zip_lists(self.to_dict(), max_width=max_width, tabstop=tabstop) + + def get(self, wildcard: str): + """Get values for a single wildcard.""" + index = self.wildcards.index(wildcard) + return [entry[index] for entry in self.entries] + + def pick(self, wildcards: Iterable[str]): + """Select wildcards without deduplication.""" + # Use dict.fromkeys for de-duplication to preserve order + indices = [self.wildcards.index(w) for w in dict.fromkeys(wildcards)] + + entries = [tuple(entry[i] for i in indices) for entry in self.entries] + + return self.__class__(wildcards=wildcards, entries=entries) + + def filter( + self, + filters: Mapping[str, Iterable[str] | str], + regex_search: bool = False, + ): + """Apply filtering operation.""" + valid_filters = set(self.wildcards) + if regex_search: + filter_sets = { + self.wildcards.index(key): ContainerBag( + *(RegexContainer(r) for r in itx.always_iterable(vals)) + ) + for key, vals in filters.items() + if key in valid_filters + } + else: + filter_sets = { + self.wildcards.index(key): set(itx.always_iterable(vals)) + for key, vals in filters.items() + if key in valid_filters + } + + keep = [ + entry + for entry in self.entries + if all( + i not in filter_sets or val in filter_sets[i] + for i, val in enumerate(entry) + ) + ] + + return self.__class__(wildcards=self.wildcards, entries=keep) diff --git a/snakebids/core/datasets.py b/snakebids/core/datasets.py index 42458641..9f386a31 100644 --- a/snakebids/core/datasets.py +++ b/snakebids/core/datasets.py @@ -7,7 +7,7 @@ from math import inf from pathlib import Path from string import Formatter -from typing import Any, Iterable, NoReturn, cast, overload +from typing import Any, Iterable, Mapping, NoReturn, cast, overload import attr import more_itertools as itx @@ -16,14 +16,14 @@ from typing_extensions import Self, TypedDict import snakebids.utils.sb_itertools as sb_it -from snakebids.core.filtering import filter_list +from snakebids.core._table import BidsTable from snakebids.exceptions import DuplicateComponentError from snakebids.io.console import get_console_size -from snakebids.io.printing import format_zip_lists, quote_wrap +from snakebids.io.printing import quote_wrap from snakebids.snakemake_compat import expand as sn_expand -from snakebids.types import ZipList +from snakebids.types import ZipList, ZipListLike from snakebids.utils.containers import ImmutableList, MultiSelectDict, UserDictPy38 -from snakebids.utils.utils import get_wildcard_dict, property_alias, zip_list_eq +from snakebids.utils.utils import get_wildcard_dict, property_alias class BidsDatasetDict(TypedDict): @@ -176,12 +176,21 @@ def filter( msg = "Both __spec and filters cannot be used simultaneously" raise ValueError(msg) filters = {self.entity: spec} - entity, data = itx.first( - filter_list( - {self.entity: self._data}, filters, regex_search=regex_search - ).items() + data = it.chain.from_iterable( + BidsTable(wildcards=[self.entity], entries=[(el,) for el in self._data]) + .filter(filters, regex_search=regex_search) + .entries ) - return self.__class__(data, entity=entity) + return self.__class__(data, entity=self.entity) + + +def _to_bids_table(tbl: BidsTable | ZipListLike) -> BidsTable: + if isinstance(tbl, BidsTable): + return tbl + if isinstance(tbl, Mapping): # type: ignore + return BidsTable.from_dict(tbl) + msg = f"Cannot convert '{tbl}' to BidsTable" + raise TypeError(msg) @attr.define(kw_only=True) @@ -208,17 +217,12 @@ class BidsPartialComponent: ``BidsPartialComponents`` are immutable: their values cannot be altered. """ - _zip_lists: ZipList = attr.field( - on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists" + _table: BidsTable = attr.field( + converter=_to_bids_table, + on_setattr=attr.setters.frozen, + alias="table", ) - @_zip_lists.validator # type: ignore - def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None: - lengths = {len(val) for val in value.values()} - if len(lengths) > 1: - msg = "zip_lists must all be of equal length" - raise ValueError(msg) - def __repr__(self) -> str: return self.pformat() @@ -232,11 +236,8 @@ def __getitem__( self, key: str | tuple[str, ...], / ) -> BidsComponentRow | BidsPartialComponent: if isinstance(key, tuple): - # Use dict.fromkeys for de-duplication - return BidsPartialComponent( - zip_lists={key: self.zip_lists[key] for key in dict.fromkeys(key)} - ) - return BidsComponentRow(self.zip_lists[key], entity=key) + return BidsPartialComponent(table=self._table.pick(key)) + return BidsComponentRow(self._table.get(key), entity=key) def __bool__(self) -> bool: """Truth of a BidsComponent is based on whether it has values. @@ -247,7 +248,7 @@ def __bool__(self) -> bool: consistent with :class:`BidsComponentRow`, which always has an entity name stored, but may or may not have values. """ - return bool(itx.first(self.zip_lists)) + return bool(self._table.entries) def _pformat_body(self) -> None | str | list[str]: """Extra properties to be printed within pformat. @@ -271,8 +272,7 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str body = it.chain( itx.always_iterable(self._pformat_body() or []), [ - "zip_lists=" - f"{format_zip_lists(self.zip_lists, width - tabstop, tabstop)},", + "table=" f"{self._table.pformat(width - tabstop, tabstop)},", ], ) output = [ @@ -292,6 +292,9 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str _entities: list[str] | None = attr.field( default=None, init=False, eq=False, repr=False ) + _zip_lists: ZipList | None = attr.field( + default=None, init=False, eq=False, repr=False + ) @property def zip_lists(self): @@ -302,6 +305,9 @@ def zip_lists(self): of images matched for this modality, so they can be zipped together to get a list of the wildcard values for each file. """ + if self._zip_lists is not None: + return self._zip_lists + self._zip_lists = self._table.to_dict() return self._zip_lists @property @@ -328,15 +334,8 @@ def wildcards(self) -> MultiSelectDict[str, str]: self._input_wildcards = MultiSelectDict(get_wildcard_dict(self.zip_lists)) return self._input_wildcards - @property + @property_alias(zip_lists, "zip_lists", "snakebids.BidsPartialComponent.zip_lists") def input_zip_lists(self) -> ZipList: - """Alias of :attr:`zip_lists `. - - Dictionary where each key is a wildcard entity and each value is a list of the - values found for that entity. Each of these lists has length equal to the number - of images matched for this modality, so they can be zipped together to get a - list of the wildcard values for each file. - """ return self.zip_lists @property_alias(entities, "entities", "snakebids.BidsComponent.entities") @@ -351,7 +350,7 @@ def __eq__(self, other: BidsComponent | object) -> bool: if not isinstance(other, self.__class__): return False - return zip_list_eq(self.zip_lists, other.zip_lists) + return self._table == other._table def expand( self, @@ -458,7 +457,7 @@ def filter( return self return attr.evolve( self, - zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search), + table=self._table.filter(filters, regex_search=regex_search), ) @@ -497,27 +496,28 @@ class BidsComponent(BidsPartialComponent): BidsComponents are immutable: their values cannot be altered. """ + _table: BidsTable = attr.field( + converter=_to_bids_table, + on_setattr=attr.setters.frozen, + alias="table", + ) + name: str = attr.field(on_setattr=attr.setters.frozen) """Name of the component""" path: str = attr.field(on_setattr=attr.setters.frozen) """Wildcard-filled path that matches the files for this component.""" - _zip_lists: ZipList = attr.field( - on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists" - ) - - @_zip_lists.validator # type: ignore - def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None: - super()._validate_zip_lists(__attr, value) + @_table.validator # type: ignore + def _validate_zip_lists(self, __attr: str, value: BidsTable) -> None: _, raw_fields, *_ = sb_it.unpack( zip(*Formatter().parse(self.path)), [[], [], []] ) raw_fields = cast("Iterable[str]", raw_fields) - if (fields := set(filter(None, raw_fields))) != set(value): + if (fields := set(filter(None, raw_fields))) != set(value.wildcards): msg = ( - "zip_lists entries must match the wildcards in input_path: " - f"{self.path}: {fields} != zip_lists: {set(value)}" + "entries have the same wildcards as the input path: " + f"{self.path}: {fields} != entries: {set(value.wildcards)}" ) raise ValueError(msg) diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index c3f8f60f..3652220d 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -7,12 +7,12 @@ import os import re import warnings -from collections import defaultdict from pathlib import Path from typing import ( Any, Iterable, Literal, + cast, overload, ) @@ -25,20 +25,20 @@ UnifiedFilter, get_matching_files, ) +from snakebids.core._table import BidsTable from snakebids.core.datasets import BidsComponent, BidsDataset, BidsDatasetDict -from snakebids.core.filtering import filter_list from snakebids.exceptions import ( + BidsParseError, ConfigError, DuplicateComponentError, RunError, ) from snakebids.snakemake_compat import Snakemake -from snakebids.types import InputConfig, InputsConfig, ZipList -from snakebids.utils.snakemake_io import glob_wildcards +from snakebids.types import InputConfig, InputsConfig +from snakebids.utils.snakemake_io import glob_wildcards_to_entries from snakebids.utils.utils import ( DEPRECATION_FLAG, BidsEntity, - BidsParseError, get_first_dir, ) @@ -579,8 +579,11 @@ def _get_component( if "custom_path" in component: path = component["custom_path"] - zip_lists = _parse_custom_path(path, filters=filters) - return BidsComponent(name=input_name, path=path, zip_lists=zip_lists) + return BidsComponent( + name=input_name, + path=path, + table=_parse_custom_path(path, filters=filters), + ) if bids_layout is None: msg = ( @@ -589,7 +592,8 @@ def _get_component( ) raise RunError(msg) - zip_lists: dict[str, list[str]] = defaultdict(list) + entries: list[tuple[str, ...]] = [] + wildcards: tuple[str, ...] = () paths: set[str] = set() try: matching_files = get_matching_files(bids_layout, filters) @@ -597,15 +601,15 @@ def _get_component( raise err.get_config_error(input_name) from err for img in matching_files: - wildcards: list[str] = [ - wildcard + entities = [ + BidsEntity.normalize(wildcard) for wildcard in set(component.get("wildcards", [])) if wildcard in img.entities ] - _logger.debug("Wildcards %s found entities for %s", wildcards, img.path) + _logger.debug("Wildcards %s found entities for %s", entities, img.path) try: - path, parsed_wildcards = _parse_bids_path(img.path, wildcards) + path, parsed_wildcards = _parse_bids_path(img.path, entities) except BidsParseError as err: msg = ( "Parsing failed:\n" @@ -626,13 +630,22 @@ def _get_component( ) raise ConfigError(msg) from err - for wildcard_name, value in parsed_wildcards.items(): - zip_lists[wildcard_name].append(value) - + entries.append(parsed_wildcards) paths.add(path) + wildcards = tuple(entity.wildcard for entity in entities) - # now, check to see if unique - if len(paths) == 0: + try: + path = itx.one(paths, too_short=TypeError) + except ValueError: + msg = ( + f"Multiple path templates for one component. Use --filter_{input_name} to " + f"narrow your search or --wildcards_{input_name} to make the template more " + "generic.\n" + f"\tcomponent = {input_name!r}\n" + f"\tpath_templates = [\n\t\t" + ",\n\t\t".join(map(repr, paths)) + "\n\t]\n" + ).expandtabs(4) + raise ConfigError(msg) from None + except TypeError: _logger.warning( "No input files found for snakebids component %s:\n" " filters:\n%s\n" @@ -649,32 +662,23 @@ def _get_component( ), ) return None - try: - path = itx.one(paths) - except ValueError as err: - msg = ( - f"Multiple path templates for one component. Use --filter_{input_name} to " - f"narrow your search or --wildcards_{input_name} to make the template more " - "generic.\n" - f"\tcomponent = {input_name!r}\n" - f"\tpath_templates = [\n\t\t" + ",\n\t\t".join(map(repr, paths)) + "\n\t]\n" - ).expandtabs(4) - raise ConfigError(msg) from err if filters.has_empty_postfilter: return BidsComponent( - name=input_name, path=path, zip_lists={key: [] for key in zip_lists} + name=input_name, path=path, table=BidsTable(wildcards=wildcards, entries=[]) ) - return BidsComponent(name=input_name, path=path, zip_lists=zip_lists).filter( - regex_search=True, **filters.post_exclusions - ) + return BidsComponent( + name=input_name, + path=path, + table=BidsTable(wildcards=wildcards, entries=entries), + ).filter(regex_search=True, **filters.post_exclusions) def _parse_custom_path( input_path: Path | str, filters: UnifiedFilter, -) -> ZipList: +) -> BidsTable: """Glob wildcards from a custom path and apply filters. This replicates pybids path globbing for any custom path. Input path should have @@ -699,27 +703,29 @@ def _parse_custom_path( ------- input_zip_list, input_list, input_wildcards """ - if not (wildcards := glob_wildcards(input_path)): + if (table := glob_wildcards_to_entries(input_path)) is None: _logger.warning("No wildcards defined in %s", input_path) + return BidsTable(wildcards=(), entries=[]) # Log an error if no matches found - if len(itx.first(wildcards.values())) == 0: + if not table: _logger.error("No matching files for %s", input_path) - return wildcards + return table # Return the output values, running filtering on the zip_lists - result = filter_list( - wildcards, + result = table.filter( filters.without_bools, regex_search=False, ) if not filters.post_exclusions: return result - return filter_list(result, filters.post_exclusions, regex_search=True) + return result.filter(filters.post_exclusions, regex_search=True) -def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, str]]: +def _parse_bids_path( + path: str, entities: Iterable[BidsEntity] +) -> tuple[str, tuple[str, ...]]: """Replace parameters in an bids path with the given wildcard {tags}. Parameters @@ -732,40 +738,31 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, Returns ------- - path : str + path Original path with the original entities replaced with wildcards. (e.g. "root/sub-{subject}/ses-{session}/sub-{subject}_ses-{session}_{suffix}") - matches : iterable of (wildcard, value) + matches The values matched with each wildcard """ # If path is relative, we need to get a slash in front of it to ensure parsing works # correctly. So prepend "./" or ".\" and run function again, then strip before # returning if _is_local_relative(path) and get_first_dir(path) != ".": - path_, wildcard_values = _parse_bids_path(os.path.join(".", path), entities) - return str(Path(path_)), wildcard_values - - entities = list(entities) - - matches = sorted( - ( - (entity, match) - for entity in map(BidsEntity, entities) - for match in re.finditer(entity.regex, path) - ), - key=lambda match: match[1].start(2), - ) - - wildcard_values: dict[str, str] = { - entity.wildcard: match.group(2) for entity, match in matches - } - if len(wildcard_values) != len(entities): - unmatched = ( - set(map(BidsEntity, entities)) - .difference(match[0] for match in matches) - .pop() - ) - raise BidsParseError(path=path, entity=unmatched) + path_, vals = _parse_bids_path(os.path.join(".", path), entities) + return str(Path(path_)), vals + + matches: list[tuple[BidsEntity, re.Match[str]]] = [] + wildcard_values: list[str] = [] + for entity in entities: + values: list[str] = [] + for match in re.finditer(entity.regex, path): + matches.append((entity, match)) + values.append(match.group(2)) + if len(values) == 0: + raise BidsParseError(path=path, entity=entity) + wildcard_values.append(values[0]) + + matches = sorted(matches, key=lambda match: match[1].start(2)) num_matches = len(matches) new_path: list[str] = [] @@ -776,7 +773,7 @@ def _parse_bids_path(path: str, entities: Iterable[str]) -> tuple[str, dict[str, new_path.append(path[start:end].replace("{", "{{").replace("}", "}}")) if i < num_matches: new_path.append(f"{{{matches[i][0].wildcard}}}") - return "".join(new_path), wildcard_values + return "".join(new_path), tuple(wildcard_values) def get_wildcard_constraints(image_types: InputsConfig) -> dict[str, str]: diff --git a/snakebids/exceptions.py b/snakebids/exceptions.py index 0d532ffc..9f04f0c3 100644 --- a/snakebids/exceptions.py +++ b/snakebids/exceptions.py @@ -2,6 +2,8 @@ from collections.abc import Iterable +from snakebids.utils.utils import BidsEntity + class ConfigError(Exception): """Exception raised for errors with the Snakebids config.""" @@ -37,3 +39,12 @@ def __init__(self, duplicated_names: Iterable[str]): class SnakebidsPluginError(Exception): """Exception raised when a Snakebids plugin encounters a problem.""" + + +class BidsParseError(Exception): + """Exception raised for errors encountered in the parsing of Bids paths.""" + + def __init__(self, path: str, entity: BidsEntity) -> None: + self.path = path + self.entity = entity + super().__init__(path, entity) diff --git a/snakebids/tests/helpers.py b/snakebids/tests/helpers.py index b31a399a..9ad1f0dc 100644 --- a/snakebids/tests/helpers.py +++ b/snakebids/tests/helpers.py @@ -26,11 +26,12 @@ from typing_extensions import ParamSpec from snakebids import bids_factory +from snakebids.core._table import BidsTable from snakebids.core.datasets import BidsDataset from snakebids.core.input_generation import generate_inputs from snakebids.paths import specs -from snakebids.types import InputsConfig, ZipList, ZipListLike -from snakebids.utils.containers import MultiSelectDict, UserDictPy38 +from snakebids.types import InputsConfig, ZipListLike +from snakebids.utils.containers import UserDictPy38 from snakebids.utils.utils import BidsEntity _T = TypeVar("_T") @@ -38,10 +39,10 @@ _P = ParamSpec("_P") -def get_zip_list( +def get_bids_entries( entities: Iterable[BidsEntity | str], combinations: Iterable[tuple[str, ...]] -) -> ZipList: - """Return a zip list from iterables of entities and value combinations +) -> BidsTable: + """Return a BidsEntries from iterables of entities and value combinations Parameters ---------- @@ -56,16 +57,9 @@ def get_zip_list( dict[str, list[str]] zip_list representation of entity-value combinations """ - - def strlist() -> list[str]: - return [] - - lists: Iterable[Sequence[str]] = list(zip(*combinations)) or itx.repeatfunc(strlist) - return MultiSelectDict( - { - BidsEntity(str(entity)).wildcard: list(combs) - for entity, combs in zip(entities, lists) - } + return BidsTable( + wildcards=(BidsEntity(str(entity)).wildcard for entity in entities), + entries=combinations, ) diff --git a/snakebids/tests/strategies.py b/snakebids/tests/strategies.py index dbebb19c..42620285 100644 --- a/snakebids/tests/strategies.py +++ b/snakebids/tests/strategies.py @@ -19,6 +19,7 @@ from bids.layout import Config as BidsConfig from hypothesis import assume +from snakebids.core._table import BidsTable from snakebids.core.datasets import ( BidsComponent, BidsComponentRow, @@ -26,7 +27,7 @@ BidsPartialComponent, ) from snakebids.tests import helpers -from snakebids.types import Expandable, InputConfig, InputsConfig, ZipList +from snakebids.types import Expandable, InputConfig, InputsConfig from snakebids.utils.containers import ContainerBag, MultiSelectDict from snakebids.utils.utils import BidsEntity @@ -188,10 +189,9 @@ def bids_path( extras = ( { - k: v[0].replace("{", "{{").replace("}", "}}") + k: v.replace("{", "{{").replace("}", "}}") for k, v in draw( - zip_lists( - max_values=1, + bids_entries( max_entities=2, blacklist_entities=ContainerBag( blacklist_entities if blacklist_entities is not None else [], @@ -296,7 +296,48 @@ def inputs_configs( @st.composite -def zip_lists( +def bids_entries( + draw: st.DrawFn, + *, + min_entities: int = 1, + max_entities: int = 5, + entities: list[BidsEntity] | None = None, + blacklist_entities: Container[BidsEntity | str] | None = None, + whitelist_entities: Container[BidsEntity | str] | None = None, + restrict_patterns: bool = False, +) -> dict[str, str]: + # Generate entity value-pairs for different "file types" + + if entities is None: + entities = draw( + bids_entity_lists( + min_size=min_entities, + max_size=max_entities, + blacklist_entities=blacklist_entities, + whitelist_entities=whitelist_entities, + ) + ) + + def filter_ints(type_: str | None): + def inner(s: str): + if type_ == "int": + return int(s) > 0 and not s.startswith("0") + return True + + return inner + + return { + BidsEntity(entity).wildcard: draw( + bids_value(entity.match if restrict_patterns else ".*").filter( + filter_ints(entity.type if restrict_patterns else None) + ) + ) + for entity in entities + } + + +@st.composite +def bids_tables( draw: st.DrawFn, *, min_entities: int = 1, @@ -309,7 +350,7 @@ def zip_lists( restrict_patterns: bool = False, unique: bool = False, cull: bool = True, -) -> ZipList: +) -> BidsTable: # Generate multiple entity sets for different "file types" if entities is None: @@ -357,7 +398,10 @@ def inner(s: str): if cull and len(combinations) else combinations ) - return helpers.get_zip_list(values, used_combinations) + return BidsTable( + wildcards=(BidsEntity(entity).wildcard for entity in values), + entries=used_combinations, + ) @st.composite @@ -437,8 +481,8 @@ def bids_partial_components( ), ) ) - zip_list = draw( - zip_lists( + table = draw( + bids_tables( min_entities=min_entities, max_entities=max_entities, min_values=min_values, @@ -452,7 +496,7 @@ def bids_partial_components( ) ) - return BidsPartialComponent(zip_lists=zip_list) + return BidsPartialComponent(table=table) @st.composite @@ -502,7 +546,7 @@ def bids_components( return BidsComponent( name=name or draw(bids_value()), path=str(path), - zip_lists=partial.zip_lists, + table=partial._table, ) diff --git a/snakebids/tests/test_datasets.py b/snakebids/tests/test_datasets.py index 59c9b5f2..05bfc1b7 100644 --- a/snakebids/tests/test_datasets.py +++ b/snakebids/tests/test_datasets.py @@ -13,6 +13,7 @@ from hypothesis import assume, given from hypothesis import strategies as st +from snakebids.core._table import BidsTable from snakebids.core.datasets import ( BidsComponent, BidsComponentRow, @@ -23,16 +24,20 @@ from snakebids.paths._presets import bids from snakebids.snakemake_compat import WildcardError from snakebids.tests import strategies as sb_st -from snakebids.tests.helpers import expand_zip_list, get_bids_path, get_zip_list, setify -from snakebids.types import Expandable, ZipList +from snakebids.tests.helpers import ( + expand_zip_list, + get_bids_path, + setify, +) +from snakebids.types import Expandable from snakebids.utils import sb_itertools as sb_it from snakebids.utils.snakemake_io import glob_wildcards from snakebids.utils.utils import BidsEntity, get_wildcard_dict, zip_list_eq def test_multiple_components_cannot_have_same_name(): - comp1 = BidsComponent(name="foo", path=".", zip_lists={}) - comp2 = BidsComponent(name="foo", path=".", zip_lists={}) + comp1 = BidsComponent(name="foo", path=".", table={}) + comp2 = BidsComponent(name="foo", path=".", table={}) with pytest.raises(DuplicateComponentError): BidsDataset.from_iterable([comp1, comp2]) @@ -57,37 +62,29 @@ def test_bids_dataset_aliases_are_correctly_set(self, component: BidsComponent): class TestBidsComponentValidation: - @given(sb_st.zip_lists().filter(lambda v: len(v) > 1)) - def test_zip_lists_must_be_same_length(self, zip_lists: ZipList): - itx.first(zip_lists.values()).append("foo") - with pytest.raises(ValueError, match="zip_lists must all be of equal length"): - BidsComponent( - name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists - ) - - @given(sb_st.zip_lists(), sb_st.bids_entity()) + @given(sb_st.bids_tables(), sb_st.bids_entity()) def test_path_cannot_have_extra_entities( - self, zip_lists: ZipList, entity: BidsEntity + self, table: BidsTable, entity: BidsEntity ): - assume(entity.wildcard not in zip_lists) - path = get_bids_path(it.chain(zip_lists, [entity.entity])) + assume(entity.wildcard not in table.wildcards) + path = get_bids_path(it.chain(table.wildcards, [entity.entity])) with pytest.raises( - ValueError, match="zip_lists entries must match the wildcards in input_path" + ValueError, match="entries have the same wildcards as the input path" ): - BidsComponent(name="foo", path=path, zip_lists=zip_lists) + BidsComponent(name="foo", path=path, table=table) - @given(sb_st.zip_lists().filter(lambda v: len(v) > 1)) - def test_path_cannot_have_missing_entities(self, zip_lists: ZipList): + @given(sb_st.bids_tables(min_entities=2)) + def test_path_cannot_have_missing_entities(self, table: BidsTable): # Snakebids strategies won't return a zip_list with just datatype, but now that # we've dropped an entity we need to check again - path_entities = sb_it.drop(1, zip_lists) + path_entities = sb_it.drop(1, table.wildcards) assume(set(path_entities) - {"datatype"}) path = get_bids_path(path_entities) with pytest.raises( - ValueError, match="zip_lists entries must match the wildcards in input_path" + ValueError, match="entries have the same wildcards as the input path" ): - BidsComponent(name="foo", path=path, zip_lists=zip_lists) + BidsComponent(name="foo", path=path, table=table) class TestBidsComponentEq: @@ -95,56 +92,26 @@ class TestBidsComponentEq: def test_other_types_are_unequal(self, comp: BidsComponent, other: Any): assert comp != other - def test_empty_bidsinput_are_equal(self): - assert BidsComponent(name="", path="", zip_lists={}) == BidsComponent( - name="", path="", zip_lists={} - ) - assert BidsComponent( - name="", - path="{foo}{bar}", - zip_lists={"foo": [], "bar": []}, - ) == BidsComponent( - name="", - path="{foo}{bar}", - zip_lists={"foo": [], "bar": []}, - ) - - @given(sb_st.bids_components()) + @given(sb_st.expandables()) def test_copies_are_equal(self, comp: BidsComponent): cp = copy.deepcopy(comp) assert cp == comp - @given(sb_st.bids_components()) - def test_mutation_makes_unequal(self, comp: BidsComponent): - cp = copy.deepcopy(comp) - itx.first(cp.zip_lists.values())[0] += "foo" - assert cp != comp - - @given(sb_st.bids_components(), st.data()) - def test_extra_entities_makes_unequal( - self, comp: BidsComponent, data: st.DataObject + @given( + first=sb_st.bids_partial_components(), second=sb_st.bids_partial_components() + ) + def test_unequal_tables_yield_unequal_components( + self, first: BidsPartialComponent, second: BidsPartialComponent ): - cp = copy.deepcopy(comp) - new_entity = data.draw( - sb_st.bids_value().filter(lambda s: s not in comp.zip_lists) - ) - cp.zip_lists[new_entity] = [] - itx.first(cp.zip_lists.values())[0] += "foo" - assert cp != comp - - @given(sb_st.bids_components()) - def test_order_doesnt_affect_equality(self, comp: BidsComponent): - cp = copy.deepcopy(comp) - for list_ in cp.zip_lists: - cp.zip_lists[list_].reverse() - assert cp == comp + assume(first._table != second._table) + assert first != second @given(sb_st.bids_components()) def test_paths_must_be_identical(self, comp: BidsComponent): cp = BidsComponent( name=comp.input_name, path=comp.input_path + "foo", - zip_lists=comp.zip_lists, + table=comp.zip_lists, ) assert cp != comp @@ -159,29 +126,24 @@ def test_input_lists_derives_from_zip_lists( # Due to the product, we can delete some of the combinations and still # regenerate our input_lists combs = list(it.product(*input_lists.values()))[min_size - 1 :] - zip_lists = get_zip_list(input_lists, combs) - path = get_bids_path(zip_lists) + table = BidsTable( + wildcards=input_lists, + entries=combs, + ) + path = get_bids_path(table.wildcards) assert setify( - BidsComponent(name="foo", path=path, zip_lists=zip_lists).entities + BidsComponent(name="foo", path=path, table=table).entities ) == setify(input_lists) - @given( - st.dictionaries( - sb_st.bids_entity().map(lambda e: e.wildcard), - sb_st.bids_value("[^.]*"), - min_size=1, - ).filter(lambda v: list(v) != ["datatype"]) - ) + @given(sb_st.bids_entries(restrict_patterns=True, blacklist_entities=["extension"])) def test_input_wildcards_derives_from_zip_lists( self, bids_entities: dict[str, str], ): - zip_lists = {entity: [val] for entity, val in bids_entities.items()} + table = BidsTable.from_dict({k: [v] for k, v in bids_entities.items()}) bids_input = BidsComponent( - name="foo", - path=get_bids_path(zip_lists), - zip_lists=zip_lists, + name="foo", path=get_bids_path(bids_entities.keys()), table=table ) wildstr = ".".join(bids_input.input_wildcards.values()) @@ -342,7 +304,7 @@ def test_expand_preserves_entry_order(self, component: Expandable): @given(path=st.text()) def test_expandable_with_no_wildcards_returns_path_unaltered(self, path: str): - component = BidsPartialComponent(zip_lists={}) + component = BidsPartialComponent(table={}) assert itx.one(component.expand(path)) == path @given(component=sb_st.expandables(min_values=0, max_values=0, path_safe=True)) @@ -422,17 +384,13 @@ def value_strat(filt: str): for filt in filters: filter_dict[filt] = data.draw( - st.one_of( - [ - st.lists( - value_strat(filt), - unique=True, - min_size=1, - max_size=5, - ), - value_strat(filt), - ] + st.lists( + value_strat(filt), + unique=True, + min_size=1, + max_size=5, ) + | value_strat(filt), ) return filter_dict diff --git a/snakebids/tests/test_generate_inputs.py b/snakebids/tests/test_generate_inputs.py index 2f989a71..6df97fe9 100644 --- a/snakebids/tests/test_generate_inputs.py +++ b/snakebids/tests/test_generate_inputs.py @@ -25,6 +25,7 @@ from pytest_mock import MockerFixture from snakebids.core._querying import PostFilter, UnifiedFilter, get_matching_files +from snakebids.core._table import BidsTable from snakebids.core.datasets import BidsComponent, BidsDataset from snakebids.core.input_generation import ( _all_custom_paths, @@ -36,7 +37,7 @@ _parse_custom_path, generate_inputs, ) -from snakebids.exceptions import ConfigError, PybidsError, RunError +from snakebids.exceptions import BidsParseError, ConfigError, PybidsError, RunError from snakebids.paths._presets import bids from snakebids.snakemake_compat import expand as sb_expand from snakebids.tests import strategies as sb_st @@ -47,14 +48,14 @@ create_dataset, create_snakebids_config, example_if, + get_bids_entries, get_bids_path, - get_zip_list, mock_data, reindex_dataset, ) from snakebids.types import InputsConfig from snakebids.utils.containers import MultiSelectDict -from snakebids.utils.utils import DEPRECATION_FLAG, BidsEntity, BidsParseError +from snakebids.utils.utils import DEPRECATION_FLAG, BidsEntity T = TypeVar("T") @@ -315,9 +316,9 @@ def test_entity_excluded_when_filter_false( "0": BidsComponent( name="0", path="ce-{ce}_space-{space}", - zip_lists={"ce": ["0"], "space": ["0"]}, + table={"ce": "0", "space": "0"}, ), - "1": BidsComponent(name="1", path="ce-{ce}", zip_lists={"ce": ["0"]}), + "1": BidsComponent(name="1", path="ce-{ce}", table={"ce": ["0"]}), } ) ) @@ -327,12 +328,12 @@ def test_entity_excluded_when_filter_false( "1": BidsComponent( name="1", path="sub-{subject}/{datatype}/sub-{subject}", - zip_lists={"subject": ["0"], "datatype": ["anat"]}, + table={"subject": ["0"], "datatype": ["anat"]}, ), "0": BidsComponent( name="0", path="sub-{subject}/sub-{subject}", - zip_lists={"subject": ["0"]}, + table={"subject": ["0"]}, ), } ) @@ -343,12 +344,12 @@ def test_entity_excluded_when_filter_false( "1": BidsComponent( name="1", path="sub-{subject}/sub-{subject}_{suffix}.foo", - zip_lists={"subject": ["0"], "suffix": ["bar"]}, + table={"subject": ["0"], "suffix": ["bar"]}, ), "0": BidsComponent( name="0", path="{suffix}.foo", - zip_lists={"suffix": ["bar"]}, + table=({"suffix": ["bar"]}), ), } ) @@ -387,7 +388,12 @@ def test_entity_excluded_when_filter_true(self, tmpdir: Path, dataset: BidsDatas data = generate_inputs(root, pybids_inputs) assert data == expected - @allow_function_scoped + @settings( + deadline=800, + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + ], + ) @given( template=sb_st.bids_components( name="template", restrict_patterns=True, unique=True, extra_entities=False @@ -450,7 +456,12 @@ def add_entity(component: BidsComponent, entity: str, value: str): result = generate_inputs(root, pybids_inputs) assert result == BidsDataset({"target": dataset["target"]}) - @allow_function_scoped + @settings( + deadline=800, + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + ], + ) @given( template=sb_st.bids_components( name="template", restrict_patterns=True, unique=True, extra_entities=False @@ -520,10 +531,12 @@ def tmpdir(self, fakefs_tmpdir: Path): component=BidsComponent( name="template", path="sub-{subject}/ses-{session}/sub-{subject}_ses-{session}", - zip_lists={ - "session": ["0A", "0a"], - "subject": ["0", "00"], - }, + table=( + { + "session": ["0A", "0a"], + "subject": ["0", "00"], + } + ), ), data=mock_data(["0a"]), ) @@ -531,10 +544,12 @@ def tmpdir(self, fakefs_tmpdir: Path): component=BidsComponent( name="template", path="sub-{subject}/sub-{subject}_mt-{mt}", - zip_lists={ - "subject": ["0", "00"], - "mt": ["on", "on"], - }, + table=( + { + "subject": ["0", "00"], + "mt": ["on", "on"], + } + ), ), data=mock_data(["0"]), ) @@ -636,22 +651,27 @@ def test_regex_match_selects_paths( ) def test_regex_search_selects_paths(self, tmpdir: Path, component: BidsComponent): root = tempfile.mkdtemp(dir=tmpdir) - entity = itx.first(component.entities) - assume(f"prefix{component[entity][0]}suffix" not in component.entities[entity]) - zip_lists = { - ent: ( - [*value, f"prefix{value[0]}suffix"] - if ent is entity - else [*value, value[0]] - ) - for ent, value in component.zip_lists.items() - } + entity = component._table.wildcards[0] + assume( + f"prefix{component._table.entries[0][0]}suffix" + not in component.entities[entity] + ) + entries = [ + *component._table.entries, + ( + f"prefix{component._table.entries[0][0]}suffix", + *component._table.entries[0][1:], + ), + ] dataset = BidsDataset.from_iterable( [ attrs.evolve( component, path=os.path.join(root, component.path), - zip_lists=zip_lists, + table=attrs.evolve( + component._table, + entries=entries, + ), ) ] ) @@ -678,7 +698,7 @@ def test_regex_search_selects_paths(self, tmpdir: Path, component: BidsComponent component=BidsComponent( name="template", path="sub-{subject}/sub-{subject}_mt-{mt}", - zip_lists={ + table={ "subject": ["0", "00"], "mt": ["on", "on"], }, @@ -789,7 +809,7 @@ def test_filter_with_multiple_methods_raises_error( self, tmpdir: Path, methods: list[str] ): dataset = BidsDataset.from_iterable( - [BidsComponent(zip_lists={}, name="template", path=str(tmpdir))] + [BidsComponent(table={}, name="template", path=str(tmpdir))] ) create_dataset("", dataset) pybids_inputs: InputsConfig = { @@ -805,7 +825,7 @@ def test_filter_with_multiple_methods_raises_error( @pytest.mark.disable_fakefs(True) def test_filter_with_no_methods_raises_error(self, tmpdir: Path): dataset = BidsDataset.from_iterable( - [BidsComponent(zip_lists={}, name="template", path=str(tmpdir))] + [BidsComponent(table={}, name="template", path=str(tmpdir))] ) create_dataset("", dataset) pybids_inputs: InputsConfig = { @@ -828,7 +848,7 @@ def test_filter_with_no_methods_raises_error(self, tmpdir: Path): ) def test_filter_with_invalid_method_raises_error(self, tmpdir: Path, method: str): dataset = BidsDataset.from_iterable( - [BidsComponent(zip_lists={}, name="template", path=str(tmpdir))] + [BidsComponent(table={}, name="template", path=str(tmpdir))] ) create_dataset("", dataset) pybids_inputs: InputsConfig = { @@ -902,7 +922,7 @@ def test_missing_filters(self, tmpdir: Path): pybids_config=str(Path(__file__).parent / "data" / "custom_config.json"), ) template = BidsDataset( - {"t1": BidsComponent(name="t1", path=config["t1"].path, zip_lists=zip_list)} + {"t1": BidsComponent(name="t1", path=config["t1"].path, table=zip_list)} ) # Order of the subjects is not deterministic assert template == config @@ -927,7 +947,13 @@ def test_missing_wildcards(self, tmpdir: Path): pybids_config=str(Path(__file__).parent / "data" / "custom_config.json"), ) template = BidsDataset( - {"t1": BidsComponent(name="t1", path=config["t1"].path, zip_lists={})} + { + "t1": BidsComponent( + name="t1", + path=config["t1"].path, + table=BidsTable(wildcards=[], entries=[()]), + ) + } ) assert template == config assert config.subj_wildcards == {"subject": "{subject}"} @@ -1094,11 +1120,11 @@ def test_collects_all_paths_when_no_filters( # Test without any filters result = _parse_custom_path(test_path, UnifiedFilter.from_filter_dict({})) - zip_lists = get_zip_list(entities, it.product(*entities.values())) + zip_lists = get_bids_entries(entities, it.product(*entities.values())) assert BidsComponent( - name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists + name="foo", path=get_bids_path(zip_lists.wildcards), table=zip_lists ) == BidsComponent( - name="foo", path=get_bids_path(result), zip_lists=MultiSelectDict(result) + name="foo", path=get_bids_path(result.wildcards), table=result ) @settings(deadline=400, suppress_health_check=[HealthCheck.function_scoped_fixture]) @@ -1112,21 +1138,19 @@ def test_collects_only_filtered_entities( test_path = self.generate_test_directory(entities, template, temp_dir) # Test with filters - result_filtered = MultiSelectDict( - _parse_custom_path(test_path, UnifiedFilter.from_filter_dict(filters)) + result_filtered = _parse_custom_path( + test_path, UnifiedFilter.from_filter_dict(filters) ) - zip_lists = MultiSelectDict( - { - # Start with empty lists for each key, otherwise keys will be missing - **{key: [] for key in entities}, - # Override entities with relevant filters before making zip lists - **get_zip_list(entities, it.product(*{**entities, **filters}.values())), - } + zip_lists = get_bids_entries( + entities, it.product(*{**entities, **filters}.values()) ) + assert BidsComponent( - name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists + name="foo", path=get_bids_path(zip_lists.wildcards), table=zip_lists ) == BidsComponent( - name="foo", path=get_bids_path(result_filtered), zip_lists=result_filtered + name="foo", + path=get_bids_path(result_filtered.wildcards), + table=result_filtered, ) @settings( @@ -1144,29 +1168,22 @@ def test_collect_all_but_filters_when_exclusion_filters_used( exclude_filters = PostFilter() for key, values in filters.items(): exclude_filters.add_filter(key, None, values) - result_excluded = MultiSelectDict( - _parse_custom_path( - test_path, UnifiedFilter.from_filter_dict({}, exclude_filters) - ) + result_excluded = _parse_custom_path( + test_path, UnifiedFilter.from_filter_dict({}, exclude_filters) ) entities_excluded = { entity: [value for value in values if value not in filters.get(entity, [])] for entity, values in entities.items() } - zip_lists = MultiSelectDict( - { - # Start with empty lists for each key, otherwise keys will be missing - **{key: [] for key in entities}, - # Override entities with relevant filters before making zip lists - **get_zip_list(entities, it.product(*entities_excluded.values())), - } - ) + zip_lists = get_bids_entries(entities, it.product(*entities_excluded.values())) assert BidsComponent( - name="foo", path=get_bids_path(zip_lists), zip_lists=zip_lists + name="foo", path=get_bids_path(zip_lists.wildcards), table=zip_lists ) == BidsComponent( - name="foo", path=get_bids_path(result_excluded), zip_lists=result_excluded + name="foo", + path=get_bids_path(result_excluded.wildcards), + table=result_excluded, ) @given( @@ -1231,7 +1248,7 @@ def test_custom_pybids_config(tmpdir: Path): foo="{foo}", suffix="T1w.nii.gz", ), - zip_lists={"foo": ["0", "1"], "subject": ["001", "001"]}, + table={"foo": ["0", "1"], "subject": ["001", "001"]}, ) } ) @@ -1316,7 +1333,7 @@ def test_t1w(): "t1": BidsComponent( name="t1", path=result["t1"].path, - zip_lists={"acq": ["mprage", "mprage"], "subject": ["001", "002"]}, + table={"acq": ["mprage", "mprage"], "subject": ["001", "002"]}, ) } ) @@ -1355,16 +1372,10 @@ def test_t1w(): "scan": BidsComponent( name="scan", path=result["scan"].path, - zip_lists={ - "acq": [ - "mprage", - ], - "subject": [ - "001", - ], - "suffix": [ - "T1w", - ], + table={ + "acq": ["mprage"], + "subject": ["001"], + "suffix": ["T1w"], }, ) } @@ -1420,7 +1431,7 @@ def test_t1w(): "t1": BidsComponent( name="t1", path=result["t1"].path, - zip_lists={ + table={ "acq": ["mprage", "mprage"], "subject": ["001", "002"], }, @@ -1428,7 +1439,7 @@ def test_t1w(): "t2": BidsComponent( name="t2", path=result["t2"].path, - zip_lists={"subject": ["002"]}, + table={"subject": ["002"]}, ), } ) @@ -1565,7 +1576,7 @@ def test_get_lists_from_bids(): template = BidsComponent( name="t1", path=wildcard_path_t1, - zip_lists={ + table={ "acq": ["mprage", "mprage"], "subject": ["001", "002"], }, @@ -1576,7 +1587,7 @@ def test_get_lists_from_bids(): template = BidsComponent( name="t2", path=wildcard_path_t2, - zip_lists={ + table={ "subject": ["002"], }, ) @@ -1682,7 +1693,7 @@ def __str__(self): # __fspath__ calls __str__ by default "1": BidsComponent( name="1", path="sub-{subject}/sub-{subject}_{suffix}{extension}", - zip_lists={ + table={ "subject": ["0"], "suffix": ["0"], "extension": [".0"], @@ -1691,7 +1702,7 @@ def __str__(self): # __fspath__ calls __str__ by default "0": BidsComponent( name="0", path="sub-{subject}/sub-{subject}{extension}", - zip_lists={ + table={ "subject": ["0"], "extension": [".0"], }, @@ -1949,9 +1960,10 @@ def test_splits_wildcards_from_path( path = component.expand()[0] if scheme is not None: path = f"{scheme}{path}" - entities = [BidsEntity.normalize(e).entity for e in component.zip_lists] + entities = [BidsEntity.normalize(e) for e in component.zip_lists] + wildcards = [BidsEntity.normalize(e).wildcard for e in component.zip_lists] tpl_path, matches = _parse_bids_path(path, entities) - assert tpl_path.format(**matches) == path + assert tpl_path.format(**dict(zip(wildcards, matches))) == path @given( component=sb_st.bids_components(max_values=1, restrict_patterns=True), @@ -1963,9 +1975,10 @@ def test_one_match_found_for_each_entity( path = component.expand()[0] if scheme is not None: path = f"{scheme}{path}" - entities = [BidsEntity.normalize(e).entity for e in component.zip_lists] + entities = [BidsEntity.normalize(e) for e in component.zip_lists] + wildcards = [BidsEntity.normalize(e).wildcard for e in component.zip_lists] _, matches = _parse_bids_path(path, entities) - assert set(matches.items()) == { + assert set(zip(wildcards, matches)) == { (key, val[0]) for key, val in component.zip_lists.items() } @@ -1982,10 +1995,10 @@ def test_missing_match_leads_to_error( path = component.expand()[0] if scheme is not None: path = f"{scheme}{path}" - entities = [BidsEntity.normalize(e).entity for e in component.zip_lists] - assume(entity.entity not in entities) + entities = [BidsEntity.normalize(e) for e in component.zip_lists] + assume(entity not in entities) with pytest.raises(BidsParseError) as err: - _parse_bids_path(path, it.chain(entities, [entity.entity])) + _parse_bids_path(path, it.chain(entities, [entity])) assert err.value.entity == entity diff --git a/snakebids/tests/test_printing.py b/snakebids/tests/test_printing.py index a068984c..ce4ddb5b 100644 --- a/snakebids/tests/test_printing.py +++ b/snakebids/tests/test_printing.py @@ -25,7 +25,7 @@ def zip_list_parser() -> pp.ParserElement: return pp.Suppress("{") + pp.Group(row)[1, ...] + pp.Suppress("}") -@given(zip_list=sb_st.zip_lists(max_entities=1, restrict_patterns=True)) +@given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) def test_ellipses_appears_when_maxwidth_too_short(zip_list: ZipList): width = len(format_zip_lists(zip_list, tabstop=0).splitlines()[1]) parsed = zip_list_parser().parse_string( @@ -34,13 +34,13 @@ def test_ellipses_appears_when_maxwidth_too_short(zip_list: ZipList): assert "ellipse" in parsed[0] -@given(zip_list=sb_st.zip_lists(max_entities=1, restrict_patterns=True)) +@given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) def test_no_ellipses_when_no_max_width(zip_list: ZipList): parsed = zip_list_parser().parse_string(format_zip_lists(zip_list, tabstop=0)) assert "ellipse" not in parsed[0] -@given(zip_list=sb_st.zip_lists(max_entities=1, restrict_patterns=True)) +@given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) def test_no_ellipses_when_max_width_long_enouth(zip_list: ZipList): width = len(format_zip_lists(zip_list, tabstop=0).splitlines()[1]) parsed = zip_list_parser().parse_string( @@ -50,7 +50,7 @@ def test_no_ellipses_when_max_width_long_enouth(zip_list: ZipList): @given( - zip_list=sb_st.zip_lists( + zip_list=sb_st.bids_tables( max_entities=1, min_values=0, max_values=0, restrict_patterns=True ) ) @@ -63,7 +63,7 @@ def test_no_ellipses_appears_when_ziplist_empty(zip_list: ZipList): @given( - zip_list=sb_st.zip_lists( + zip_list=sb_st.bids_tables( min_values=1, max_values=4, max_entities=4, restrict_patterns=True ), width=st.integers(min_value=10, max_value=200), @@ -91,7 +91,7 @@ def test_values_balanced_around_elision_correctly(zip_list: ZipList, width: int) class TestCorrectNumberOfLinesCreated: @given( - zip_list=sb_st.zip_lists( + zip_list=sb_st.bids_tables( min_values=0, max_values=1, max_entities=6, restrict_patterns=True ), ) @@ -122,7 +122,7 @@ def test_in_dataset(self, dataset: BidsDataset): class TestIsValidPython: @given( - zip_list=sb_st.zip_lists(restrict_patterns=True, min_values=0, min_entities=0) + zip_list=sb_st.bids_tables(restrict_patterns=True, min_values=0, min_entities=0) ) def test_in_zip_list(self, zip_list: ZipList): assert eval(format_zip_lists(zip_list, inf)) == zip_list @@ -140,7 +140,7 @@ def test_in_dataset(self, dataset: BidsDataset): # path and name are allowed to be longer than the width, so finding the zip_list lines # would prove more challenging than it's worth @given( - zip_list=sb_st.zip_lists(max_entities=1, restrict_patterns=True), + zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True), width=st.integers(10, 100), tab=st.integers(0, 10), ) @@ -158,7 +158,7 @@ def get_indent_length(line: str): class TestIndentLengthMultipleOfTabStop: @given( - zip_list=sb_st.zip_lists(restrict_patterns=True, min_values=0), + zip_list=sb_st.bids_tables(restrict_patterns=True, min_values=0), tabstop=st.integers(1, 10), ) def test_in_zip_list(self, zip_list: ZipList, tabstop: int): diff --git a/snakebids/tests/test_tables.py b/snakebids/tests/test_tables.py new file mode 100644 index 00000000..37fd6093 --- /dev/null +++ b/snakebids/tests/test_tables.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import copy +from typing import Any + +import attrs +import more_itertools as itx +import pytest +from hypothesis import assume, given + +from snakebids.core._table import BidsTable +from snakebids.core.datasets import BidsComponent +from snakebids.tests import strategies as sb_st +from snakebids.tests.helpers import get_bids_path +from snakebids.utils.utils import BidsEntity + + +@given(sb_st.bids_tables(min_entities=2)) +def test_zip_lists_must_be_same_length(table: BidsTable): + zip_lists = table.to_dict() + itx.first(zip_lists.values()).append("foo") + with pytest.raises( + ValueError, match="each entity must have the same number of values" + ): + BidsComponent(name="foo", path=get_bids_path(zip_lists), table=zip_lists) + + +class TestEq: + @given(sb_st.bids_tables(), sb_st.everything_except(BidsTable)) + def test_other_types_are_unequal(self, table: BidsTable, other: Any): + assert table != other + + @given(sb_st.bids_tables(min_entities=0, min_values=0)) + def test_copied_object_is_equal(self, table: BidsTable): + other = copy.deepcopy(table) + assert table == other + + @given(sb_st.bids_tables(min_entities=2, min_values=0)) + def test_wildcard_order_is_irrelevant(self, table: BidsTable): + other = copy.deepcopy(table) + reordered = BidsTable( + wildcards=reversed(other.wildcards), + entries=[tuple(reversed(entry)) for entry in other.entries], + ) + assert table == reordered + + @given( + table=sb_st.bids_tables(min_entities=1, min_values=0), + wildcard=sb_st.bids_entity(), + ) + def test_wildcards_must_be_the_same(self, table: BidsTable, wildcard: BidsEntity): + assume(wildcard.wildcard not in table.wildcards) + other = copy.deepcopy(table) + reordered = attrs.evolve(other, wildcards=[wildcard, *other.wildcards[1:]]) + assert table != reordered + + @given(sb_st.bids_tables()) + def test_mutation_makes_unequal(self, table: BidsTable): + cp = copy.deepcopy(table) + other = attrs.evolve( + cp, + entries=[ + ("0" + "".join(cp.entries[0]), *cp.entries[0][1:]), + *cp.entries[1:], + ], + ) + assert table != other + + @given(sb_st.bids_tables()) + def test_extra_entry_makes_unequal(self, table: BidsTable): + cp = copy.deepcopy(table) + other = attrs.evolve(cp, entries=[cp.entries[0], *cp.entries]) + assert table != other + + @given(sb_st.bids_tables()) + def test_missing_entry_makes_unequal(self, table: BidsTable): + cp = copy.deepcopy(table) + other = attrs.evolve(cp, entries=cp.entries[1:]) + assert table != other diff --git a/snakebids/utils/snakemake_io.py b/snakebids/utils/snakemake_io.py index 36d3368f..faf9871d 100644 --- a/snakebids/utils/snakemake_io.py +++ b/snakebids/utils/snakemake_io.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Sequence +from snakebids.core._table import BidsTable from snakebids.types import ZipList from snakebids.utils.containers import MultiSelectDict @@ -118,6 +119,67 @@ def glob_wildcards( return MultiSelectDict(wildcards) +def glob_wildcards_to_entries( + pattern: str | Path, + files: Sequence[str | Path] | None = None, + followlinks: bool = False, +) -> BidsTable | None: + """Glob the values of wildcards by matching a pattern to the filesystem. + + Returns a zip_list of field names with matched wildcard values. + + Parameters + ---------- + pattern + Path including wildcards to glob on the filesystem. + files + Files from which to glob wildcards. If None (default), the directory + corresponding to the first wildcard in the pattern is walked, and + wildcards are globbed from all files. + followlinks + Whether to follow links when globbing wildcards. + """ + pattern = os.path.normpath(pattern) + first_wildcard = re.search("{[^{]", pattern) + dirname = ( + os.path.dirname(pattern[: first_wildcard.start()]) + if first_wildcard + else os.path.dirname(pattern) + ) + if not dirname: + dirname = Path(".") + + names = [match.group("name") for match in _wildcard_regex.finditer(pattern)] + + # remove duplicates while preserving ordering + names = tuple(dict.fromkeys(names)) + if not names: + return None + + entries: list[tuple[str, ...]] = [] + + re_pattern = re.compile(regex(pattern)) + + file_iter = ( + ( + Path(dirpath, f) + for dirpath, dirnames, filenames in os.walk( + dirname, followlinks=followlinks + ) + for f in chain(filenames, dirnames) + ) + if files is None + else iter(files) + ) + + for f in file_iter: + if match := re.match(re_pattern, str(f)): + values = tuple(match.group(name) for name in names) + entries.append(values) + + return BidsTable(wildcards=names, entries=entries) + + def update_wildcard_constraints( pattern: str, wildcard_constraints: dict[str, str], diff --git a/snakebids/utils/utils.py b/snakebids/utils/utils.py index cd0774ac..01af1f61 100644 --- a/snakebids/utils/utils.py +++ b/snakebids/utils/utils.py @@ -225,15 +225,6 @@ def matches_any( return any(match_func(match, item, *args) for match in match_list) -class BidsParseError(Exception): - """Exception raised for errors encountered in the parsing of Bids paths.""" - - def __init__(self, path: str, entity: BidsEntity) -> None: - self.path = path - self.entity = entity - super().__init__(path, entity) - - class _Documented(Protocol): __doc__: str From 15329e5b73364edaf2fe7c348ab892cef1759472 Mon Sep 17 00:00:00 2001 From: Peter Van Dyken Date: Wed, 2 Oct 2024 16:04:05 -0400 Subject: [PATCH 2/2] Fixes to tests --- snakebids/core/_table.py | 5 +-- snakebids/core/input_generation.py | 1 - snakebids/tests/test_printing.py | 56 ++++++++++-------------------- 3 files changed, 22 insertions(+), 40 deletions(-) diff --git a/snakebids/core/_table.py b/snakebids/core/_table.py index 95677227..acf23911 100644 --- a/snakebids/core/_table.py +++ b/snakebids/core/_table.py @@ -82,11 +82,12 @@ def get(self, wildcard: str): def pick(self, wildcards: Iterable[str]): """Select wildcards without deduplication.""" # Use dict.fromkeys for de-duplication to preserve order - indices = [self.wildcards.index(w) for w in dict.fromkeys(wildcards)] + unique_keys = list(dict.fromkeys(wildcards)) + indices = [self.wildcards.index(w) for w in unique_keys] entries = [tuple(entry[i] for i in indices) for entry in self.entries] - return self.__class__(wildcards=wildcards, entries=entries) + return self.__class__(wildcards=unique_keys, entries=entries) def filter( self, diff --git a/snakebids/core/input_generation.py b/snakebids/core/input_generation.py index 3652220d..5620b179 100644 --- a/snakebids/core/input_generation.py +++ b/snakebids/core/input_generation.py @@ -12,7 +12,6 @@ Any, Iterable, Literal, - cast, overload, ) diff --git a/snakebids/tests/test_printing.py b/snakebids/tests/test_printing.py index ce4ddb5b..37b433bb 100644 --- a/snakebids/tests/test_printing.py +++ b/snakebids/tests/test_printing.py @@ -8,9 +8,8 @@ from hypothesis import strategies as st import snakebids.tests.strategies as sb_st +from snakebids.core._table import BidsTable from snakebids.core.datasets import BidsComponent, BidsDataset -from snakebids.io.printing import format_zip_lists -from snakebids.types import ZipList def zip_list_parser() -> pp.ParserElement: @@ -26,39 +25,22 @@ def zip_list_parser() -> pp.ParserElement: @given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) -def test_ellipses_appears_when_maxwidth_too_short(zip_list: ZipList): - width = len(format_zip_lists(zip_list, tabstop=0).splitlines()[1]) - parsed = zip_list_parser().parse_string( - format_zip_lists(zip_list, width - 1, tabstop=0) - ) +def test_ellipses_appears_when_maxwidth_too_short(zip_list: BidsTable): + width = len(zip_list.pformat(tabstop=0).splitlines()[1]) + parsed = zip_list_parser().parse_string(zip_list.pformat(width - 1, tabstop=0)) assert "ellipse" in parsed[0] @given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) -def test_no_ellipses_when_no_max_width(zip_list: ZipList): - parsed = zip_list_parser().parse_string(format_zip_lists(zip_list, tabstop=0)) +def test_no_ellipses_when_no_max_width(zip_list: BidsTable): + parsed = zip_list_parser().parse_string(zip_list.pformat(tabstop=0)) assert "ellipse" not in parsed[0] @given(zip_list=sb_st.bids_tables(max_entities=1, restrict_patterns=True)) -def test_no_ellipses_when_max_width_long_enouth(zip_list: ZipList): - width = len(format_zip_lists(zip_list, tabstop=0).splitlines()[1]) - parsed = zip_list_parser().parse_string( - format_zip_lists(zip_list, width, tabstop=0) - ) - assert "ellipse" not in parsed[0] - - -@given( - zip_list=sb_st.bids_tables( - max_entities=1, min_values=0, max_values=0, restrict_patterns=True - ) -) -def test_no_ellipses_appears_when_ziplist_empty(zip_list: ZipList): - width = len(format_zip_lists(zip_list, tabstop=0).splitlines()[1]) - parsed = zip_list_parser().parse_string( - format_zip_lists(zip_list, width - 1, tabstop=0) - ) +def test_no_ellipses_when_max_width_long_enough(zip_list: BidsTable): + width = len(zip_list.pformat(tabstop=0).splitlines()[1]) + parsed = zip_list_parser().parse_string(zip_list.pformat(width, tabstop=0)) assert "ellipse" not in parsed[0] @@ -68,9 +50,9 @@ def test_no_ellipses_appears_when_ziplist_empty(zip_list: ZipList): ), width=st.integers(min_value=10, max_value=200), ) -def test_values_balanced_around_elision_correctly(zip_list: ZipList, width: int): +def test_values_balanced_around_elision_correctly(zip_list: BidsTable, width: int): parsed: pp.ParseResults = zip_list_parser().parse_string( - format_zip_lists(zip_list, max_width=width, tabstop=0) + zip_list.pformat(max_width=width, tabstop=0) ) assert parsed assert parsed[0] @@ -95,9 +77,9 @@ class TestCorrectNumberOfLinesCreated: min_values=0, max_values=1, max_entities=6, restrict_patterns=True ), ) - def test_in_zip_list(self, zip_list: ZipList): + def test_in_zip_list(self, zip_list: BidsTable): assert ( - len(format_zip_lists(zip_list, tabstop=0).splitlines()) == len(zip_list) + 2 + len(zip_list.pformat(tabstop=0).splitlines()) == len(zip_list.wildcards) + 2 ) @given( @@ -124,8 +106,8 @@ class TestIsValidPython: @given( zip_list=sb_st.bids_tables(restrict_patterns=True, min_values=0, min_entities=0) ) - def test_in_zip_list(self, zip_list: ZipList): - assert eval(format_zip_lists(zip_list, inf)) == zip_list + def test_in_zip_list(self, zip_list: BidsTable): + assert eval(zip_list.pformat(inf)) == zip_list.to_dict() @given(component=sb_st.bids_components(restrict_patterns=True, min_values=0)) def test_in_component(self, component: BidsComponent): @@ -144,9 +126,9 @@ def test_in_dataset(self, dataset: BidsDataset): width=st.integers(10, 100), tab=st.integers(0, 10), ) -def test_line_never_longer_than_max_width(zip_list: ZipList, width: int, tab: int): +def test_line_never_longer_than_max_width(zip_list: BidsTable, width: int, tab: int): assume(width > tab + 10) - formatted = format_zip_lists(zip_list, width, tab) + formatted = zip_list.pformat(width, tab) parsed = zip_list_parser().parse_string(formatted) assume("left" in parsed[0]) assert all(len(line) <= width for line in formatted.splitlines()) @@ -161,8 +143,8 @@ class TestIndentLengthMultipleOfTabStop: zip_list=sb_st.bids_tables(restrict_patterns=True, min_values=0), tabstop=st.integers(1, 10), ) - def test_in_zip_list(self, zip_list: ZipList, tabstop: int): - for line in format_zip_lists(zip_list, tabstop=tabstop).splitlines(): + def test_in_zip_list(self, zip_list: BidsTable, tabstop: int): + for line in zip_list.pformat(tabstop=tabstop).splitlines(): assert get_indent_length(line) / tabstop in {0, 1} @given(