Skip to content

Commit

Permalink
Transition from zip-list to entry core
Browse files Browse the repository at this point in the history
The entry table uses entry-first strides. This will be a much more
efficient storage method for a future API that makes individual entries
more accessible. As much of the core as is currently feasible is
rewritten to use entries over zip-lists.

The public API is completely unchanged as of now.

Tests for filtering are largely unchanged, although some may be
redundant at this point.
  • Loading branch information
pvandyken committed Oct 2, 2024
1 parent 5aab121 commit 1e00e3b
Show file tree
Hide file tree
Showing 12 changed files with 605 additions and 334 deletions.
122 changes: 122 additions & 0 deletions snakebids/core/_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping

import attrs
import more_itertools as itx

from snakebids.io.printing import format_zip_lists
from snakebids.types import ZipList, ZipListLike
from snakebids.utils.containers import ContainerBag, MultiSelectDict, RegexContainer

# Converter/factory shims used by BidsTable's attrs fields. At runtime they are
# simply the builtins; the typed function definitions exist only so that static
# type checkers see precise input/output types for the converters.
if not TYPE_CHECKING:
    wcard_tuple = tuple
    entries_list = list
    liststr = list
else:

    def wcard_tuple(x: Iterable[str]) -> tuple[str, ...]:
        return tuple(x)

    def entries_list(x: Iterable[tuple[str, ...]]) -> list[tuple[str, ...]]:
        return list(x)

    def liststr() -> list[str]:
        return []


@attrs.frozen(kw_only=True)
class BidsTable:
    """Container holding the entries of a BidsComponent.

    Data is stored entry-first: ``wildcards`` names the columns, and each
    element of ``entries`` is one row holding a value for every wildcard, in
    the same order as ``wildcards``.
    """

    # Column names; any iterable of str is converted to a tuple on construction.
    wildcards: tuple[str, ...] = attrs.field(converter=wcard_tuple)
    # Rows; any iterable of tuples is converted to a list on construction.
    entries: list[tuple[str, ...]] = attrs.field(converter=entries_list)

    def __bool__(self):
        """Return True if one or more entries, otherwise False."""
        return bool(self.entries)

    def __eq__(self, other: BidsTable | object):
        """Compare two tables, ignoring entry order and wildcard order.

        Tables are equal when they hold the same set of wildcards and the same
        multiset of entries once columns are aligned by wildcard name.
        """
        if not isinstance(other, self.__class__):
            return False
        if set(self.wildcards) != set(other.wildcards):
            return False
        if len(self.entries) != len(other.entries):
            return False
        if self.wildcards == other.wildcards:
            # Columns already line up: compare rows irrespective of order.
            return sorted(self.entries) == sorted(other.entries)
        # Columns are permuted: for each of our wildcards, find its position
        # in ``other`` so that other's rows can be rewritten in our order.
        ixs = [other.wildcards.index(w) for w in self.wildcards]
        # Work on a copy so matched rows can be crossed off one by one; this
        # respects duplicate rows (multiset comparison).
        entries = self.entries.copy()
        try:
            for entry in other.entries:
                sorted_entry = tuple(entry[i] for i in ixs)
                entries.remove(sorted_entry)
        except ValueError:
            # A row of ``other`` had no remaining counterpart in ``self``.
            return False
        return True

    @classmethod
    def from_dict(cls, d: ZipListLike):
        """Construct BidsTable from a mapping of entities to value lists.

        Raises
        ------
        ValueError
            If the value lists are not all the same length.
        """
        lengths = {len(val) for val in d.values()}
        if len(lengths) > 1:
            msg = "each entity must have the same number of values"
            raise ValueError(msg)
        # Transpose the column-wise mapping into row-wise entries.
        return cls(wildcards=d.keys(), entries=zip(*d.values()))

    def to_dict(self) -> ZipList:
        """Convert into a zip_list.

        Each wildcard is mapped to the list of its values across all entries.
        With no entries, each wildcard is mapped to its own empty list.
        """
        if not self.entries:
            # zip(*[]) would yield nothing and drop the keys entirely, so the
            # empty columns are produced explicitly.
            return MultiSelectDict(zip(self.wildcards, itx.repeatfunc(liststr)))
        return MultiSelectDict(zip(self.wildcards, map(list, zip(*self.entries))))

    def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str:
        """Pretty-format the table via its zip_list representation."""
        return format_zip_lists(self.to_dict(), max_width=max_width, tabstop=tabstop)

    def get(self, wildcard: str):
        """Get values for a single wildcard.

        Raises
        ------
        ValueError
            If ``wildcard`` is not one of this table's wildcards.
        """
        index = self.wildcards.index(wildcard)
        return [entry[index] for entry in self.entries]

    def pick(self, wildcards: Iterable[str]):
        """Select a subset of wildcards, in the order given.

        Repeated names in ``wildcards`` are collapsed to their first
        occurrence. Entries are projected onto the selected columns as-is;
        rows that become identical after projection are *not* merged.

        Raises
        ------
        ValueError
            If any requested wildcard is not one of this table's wildcards.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        # Materialize the result once: the same names must be used both to
        # compute the column indices and to label the new table, otherwise
        # repeated names (or a one-shot iterator) would leave the new table's
        # wildcards out of step with the width of its entries.
        unique = tuple(dict.fromkeys(wildcards))
        indices = [self.wildcards.index(w) for w in unique]

        entries = [tuple(entry[i] for i in indices) for entry in self.entries]

        return self.__class__(wildcards=unique, entries=entries)

    def filter(
        self,
        filters: Mapping[str, Iterable[str] | str],
        regex_search: bool = False,
    ):
        """Apply filtering operation.

        Parameters
        ----------
        filters
            Mapping of wildcard name to one allowed value or an iterable of
            allowed values. Keys that are not wildcards of this table are
            ignored.
        regex_search
            If True, each filter value is wrapped in a ``RegexContainer`` and
            membership is tested against the resulting ``ContainerBag``
            (presumably regex matching -- see those classes for the exact
            semantics). Otherwise values are tested for exact set membership.

        Returns
        -------
        A new table with the same wildcards, keeping only the entries whose
        value in every filtered column is accepted by that column's filter.
        """
        valid_filters = set(self.wildcards)
        # Key the filter containers by column index so that rows can be
        # checked positionally without repeated name lookups.
        if regex_search:
            filter_sets = {
                self.wildcards.index(key): ContainerBag(
                    *(RegexContainer(r) for r in itx.always_iterable(vals))
                )
                for key, vals in filters.items()
                if key in valid_filters
            }
        else:
            filter_sets = {
                self.wildcards.index(key): set(itx.always_iterable(vals))
                for key, vals in filters.items()
                if key in valid_filters
            }

        keep = [
            entry
            for entry in self.entries
            if all(
                # Unfiltered columns always pass.
                i not in filter_sets or val in filter_sets[i]
                for i, val in enumerate(entry)
            )
        ]

        return self.__class__(wildcards=self.wildcards, entries=keep)
94 changes: 47 additions & 47 deletions snakebids/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from math import inf
from pathlib import Path
from string import Formatter
from typing import Any, Iterable, NoReturn, cast, overload
from typing import Any, Iterable, Mapping, NoReturn, cast, overload

import attr
import more_itertools as itx
Expand All @@ -16,14 +16,14 @@
from typing_extensions import Self, TypedDict

import snakebids.utils.sb_itertools as sb_it
from snakebids.core.filtering import filter_list
from snakebids.core._table import BidsTable
from snakebids.exceptions import DuplicateComponentError
from snakebids.io.console import get_console_size
from snakebids.io.printing import format_zip_lists, quote_wrap
from snakebids.io.printing import quote_wrap
from snakebids.snakemake_compat import expand as sn_expand
from snakebids.types import ZipList
from snakebids.types import ZipList, ZipListLike
from snakebids.utils.containers import ImmutableList, MultiSelectDict, UserDictPy38
from snakebids.utils.utils import get_wildcard_dict, property_alias, zip_list_eq
from snakebids.utils.utils import get_wildcard_dict, property_alias


class BidsDatasetDict(TypedDict):
Expand Down Expand Up @@ -176,12 +176,21 @@ def filter(
msg = "Both __spec and filters cannot be used simultaneously"
raise ValueError(msg)
filters = {self.entity: spec}
entity, data = itx.first(
filter_list(
{self.entity: self._data}, filters, regex_search=regex_search
).items()
data = it.chain.from_iterable(
BidsTable(wildcards=[self.entity], entries=[(el,) for el in self._data])
.filter(filters, regex_search=regex_search)
.entries
)
return self.__class__(data, entity=entity)
return self.__class__(data, entity=self.entity)


def _to_bids_table(tbl: BidsTable | ZipListLike) -> BidsTable:
    """Coerce ``tbl`` into a ``BidsTable``.

    An existing ``BidsTable`` is returned unchanged, a mapping is handed to
    ``BidsTable.from_dict``, and anything else raises ``TypeError``.
    """
    if not isinstance(tbl, BidsTable):
        if not isinstance(tbl, Mapping):  # type: ignore
            msg = f"Cannot convert '{tbl}' to BidsTable"
            raise TypeError(msg)
        return BidsTable.from_dict(tbl)
    return tbl


@attr.define(kw_only=True)
Expand All @@ -208,17 +217,12 @@ class BidsPartialComponent:
``BidsPartialComponents`` are immutable: their values cannot be altered.
"""

_zip_lists: ZipList = attr.field(
on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists"
_table: BidsTable = attr.field(
converter=_to_bids_table,
on_setattr=attr.setters.frozen,
alias="table",
)

@_zip_lists.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None:
lengths = {len(val) for val in value.values()}
if len(lengths) > 1:
msg = "zip_lists must all be of equal length"
raise ValueError(msg)

def __repr__(self) -> str:
return self.pformat()

Expand All @@ -232,11 +236,8 @@ def __getitem__(
self, key: str | tuple[str, ...], /
) -> BidsComponentRow | BidsPartialComponent:
if isinstance(key, tuple):
# Use dict.fromkeys for de-duplication
return BidsPartialComponent(
zip_lists={key: self.zip_lists[key] for key in dict.fromkeys(key)}
)
return BidsComponentRow(self.zip_lists[key], entity=key)
return BidsPartialComponent(table=self._table.pick(key))
return BidsComponentRow(self._table.get(key), entity=key)

def __bool__(self) -> bool:
"""Truth of a BidsComponent is based on whether it has values.
Expand All @@ -247,7 +248,7 @@ def __bool__(self) -> bool:
consistent with :class:`BidsComponentRow`, which always has an entity name
stored, but may or may not have values.
"""
return bool(itx.first(self.zip_lists))
return bool(self._table.entries)

def _pformat_body(self) -> None | str | list[str]:
"""Extra properties to be printed within pformat.
Expand All @@ -271,8 +272,7 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str
body = it.chain(
itx.always_iterable(self._pformat_body() or []),
[
"zip_lists="
f"{format_zip_lists(self.zip_lists, width - tabstop, tabstop)},",
"table=" f"{self._table.pformat(width - tabstop, tabstop)},",
],
)
output = [
Expand All @@ -292,6 +292,9 @@ def pformat(self, max_width: int | float | None = None, tabstop: int = 4) -> str
_entities: list[str] | None = attr.field(
default=None, init=False, eq=False, repr=False
)
_zip_lists: ZipList | None = attr.field(
default=None, init=False, eq=False, repr=False
)

@property
def zip_lists(self):
Expand All @@ -302,6 +305,9 @@ def zip_lists(self):
of images matched for this modality, so they can be zipped together to get a
list of the wildcard values for each file.
"""
if self._zip_lists is not None:
return self._zip_lists
self._zip_lists = self._table.to_dict()
return self._zip_lists

@property
Expand All @@ -328,15 +334,8 @@ def wildcards(self) -> MultiSelectDict[str, str]:
self._input_wildcards = MultiSelectDict(get_wildcard_dict(self.zip_lists))
return self._input_wildcards

@property
@property_alias(zip_lists, "zip_lists", "snakebids.BidsPartialComponent.zip_lists")
def input_zip_lists(self) -> ZipList:
"""Alias of :attr:`zip_lists <snakebids.BidsComponent.zip_lists>`.
Dictionary where each key is a wildcard entity and each value is a list of the
values found for that entity. Each of these lists has length equal to the number
of images matched for this modality, so they can be zipped together to get a
list of the wildcard values for each file.
"""
return self.zip_lists

@property_alias(entities, "entities", "snakebids.BidsComponent.entities")
Expand All @@ -351,7 +350,7 @@ def __eq__(self, other: BidsComponent | object) -> bool:
if not isinstance(other, self.__class__):
return False

return zip_list_eq(self.zip_lists, other.zip_lists)
return self._table == other._table

def expand(
self,
Expand Down Expand Up @@ -458,7 +457,7 @@ def filter(
return self
return attr.evolve(
self,
zip_lists=filter_list(self.zip_lists, filters, regex_search=regex_search),
table=self._table.filter(filters, regex_search=regex_search),
)


Expand Down Expand Up @@ -497,27 +496,28 @@ class BidsComponent(BidsPartialComponent):
BidsComponents are immutable: their values cannot be altered.
"""

_table: BidsTable = attr.field(
converter=_to_bids_table,
on_setattr=attr.setters.frozen,
alias="table",
)

name: str = attr.field(on_setattr=attr.setters.frozen)
"""Name of the component"""

path: str = attr.field(on_setattr=attr.setters.frozen)
"""Wildcard-filled path that matches the files for this component."""

_zip_lists: ZipList = attr.field(
on_setattr=attr.setters.frozen, converter=MultiSelectDict, alias="zip_lists"
)

@_zip_lists.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: dict[str, list[str]]) -> None:
super()._validate_zip_lists(__attr, value)
@_table.validator # type: ignore
def _validate_zip_lists(self, __attr: str, value: BidsTable) -> None:
_, raw_fields, *_ = sb_it.unpack(
zip(*Formatter().parse(self.path)), [[], [], []]
)
raw_fields = cast("Iterable[str]", raw_fields)
if (fields := set(filter(None, raw_fields))) != set(value):
if (fields := set(filter(None, raw_fields))) != set(value.wildcards):
msg = (
"zip_lists entries must match the wildcards in input_path: "
f"{self.path}: {fields} != zip_lists: {set(value)}"
"entries have the same wildcards as the input path: "
f"{self.path}: {fields} != entries: {set(value.wildcards)}"
)
raise ValueError(msg)

Expand Down
Loading

0 comments on commit 1e00e3b

Please sign in to comment.