attr_conflict argument in spool.chunk (#246)
* fix coord issue

* fix degenerate coord, sort directory spool contents

* remove rolling cache

* test index version warnings

* start fix_242

* add attr_conflict kwargs

* fix coverage

* better error message

* attr_conflict -> conflict

* fix typo
d-chambers authored Sep 5, 2023
1 parent 2251a94 commit 453deee
Showing 10 changed files with 167 additions and 34 deletions.
10 changes: 10 additions & 0 deletions dascore/constants.py
@@ -174,6 +174,16 @@ def map(self, func, iterables, **kwargs):
same units as the specified dimension, or have units attached.
"""

attr_conflict_description = """
Indicates how to handle conflicts in attributes other than those
indicated by dim (e.g. tag, history, station). If "drop", simply
drop conflicting attributes, or attributes not shared by all models.
If "raise", raise an
[AttributeMergeError](`dascore.exceptions.AttributeMergeError`) when
issues are encountered. If "keep_first", just keep the first value
for each attribute.
"""


# Rich styles for various object displays.
dascore_styles = dict(
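A usage sketch of the three documented modes, using dascore's bundled example data (the unequal attribute values are hypothetical):

```python
import dascore as dc

# Sketch, assuming the spool's patches carry unequal non-dim attrs
# (e.g. differing instrument_id values). "keep_first" keeps the first
# value, "drop" discards conflicting attrs, and the default "raise"
# raises an error naming the unequal keys.
spool = dc.get_example_spool()
merged = spool.chunk(time=None, conflict="keep_first")
```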
22 changes: 13 additions & 9 deletions dascore/core/attrs.py
@@ -18,6 +18,7 @@
VALID_DATA_CATEGORIES,
VALID_DATA_TYPES,
PatchType,
attr_conflict_description,
basic_summary_attrs,
max_lens,
)
@@ -27,6 +28,7 @@
from dascore.utils.docs import compose_docstring
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import (
_dict_list_diffs,
all_diffs_close_enough,
get_middle_value,
iterate,
@@ -263,6 +265,7 @@ def flat_dump(self, dim_tuple=False, exclude=None) -> dict:
return out


@compose_docstring(conflict_desc=attr_conflict_description)
def combine_patch_attrs(
model_list: Sequence[PatchAttrs],
coord_name: str | None = None,
@@ -280,12 +283,7 @@
coord_name
The coordinate, usually a dimension coord, along which to merge.
conflicts
Indicates how to handle conflicts in attributes other than those
indicated by dim. If "drop" simply drop conflicting attributes,
or attributes not shared by all models. If "raise" raise an
[AttributeMergeError](`dascore.exceptions.AttributeMergeError`] when
issues are encountered. If "keep_first", just keep the first value
for each attribute.
{conflict_desc}
drop_attrs
If provided, attributes which should be dropped.
coord
@@ -393,12 +391,18 @@ def _handle_other_attrs(mod_dict_list):
if conflicts == "keep_first":
return [dict(ChainMap(*mod_dict_list))]
no_null_ = _replace_null_with_None(mod_dict_list)
all_eq = all(no_null_[0] == x for x in no_null_)
all_eq = all(no_null_[0] == x for x in no_null_[1:])
if all_eq:
return mod_dict_list
# now the fun part.
if conflicts == "raise":
msg = "Cannot merge models, not all of their non-dim attrs are equal."
# determine which keys are not equal to help debug.
uneq_keys = _dict_list_diffs(mod_dict_list)
msg = (
"Cannot merge models, the following non-dim attrs are not "
f"equal: {uneq_keys}. Consider setting the `conflict` or "
f"`attr_conflict` arguments for more flexibility in merging "
f"unequal coordinates."
)
raise AttributeMergeError(msg)
final_dict = reduce(_keep_eq, mod_dict_list)
return [final_dict]
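A hedged sketch of the new error path and its resolution, assuming `update_attrs` on the bundled example patch behaves as in current dascore and that identical time spans merge cleanly at the attrs level:

```python
import dascore as dc
from dascore.core.attrs import combine_patch_attrs
from dascore.exceptions import AttributeMergeError

# Two attribute models that differ only in a non-dim field ("tag").
patch = dc.get_example_patch()
attrs_raw = patch.update_attrs(tag="raw").attrs
attrs_proc = patch.update_attrs(tag="processed").attrs

try:
    combine_patch_attrs([attrs_raw, attrs_proc], "time")
except AttributeMergeError as err:
    print(err)  # message now lists the unequal keys, e.g. ['tag']

# "keep_first" resolves the conflict by keeping the first value.
merged = combine_patch_attrs([attrs_raw, attrs_proc], "time", conflicts="keep_first")
assert merged.tag == "raw"
```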
54 changes: 42 additions & 12 deletions dascore/core/spool.py
@@ -5,7 +5,7 @@
from collections.abc import Callable, Generator, Mapping, Sequence
from functools import singledispatch
from pathlib import Path
from typing import TypeVar
from typing import Literal, TypeVar

import numpy as np
import pandas as pd
@@ -14,7 +14,13 @@

import dascore as dc
import dascore.io
from dascore.constants import ExecutorType, PatchType, numeric_types, timeable_types
from dascore.constants import (
ExecutorType,
PatchType,
attr_conflict_description,
numeric_types,
timeable_types,
)
from dascore.exceptions import InvalidSpoolError, ParameterError
from dascore.utils.chunk import ChunkManager
from dascore.utils.display import get_dascore_text, get_nice_text
@@ -70,13 +76,14 @@ def __eq__(self, other) -> bool:
"""Simple equality checks on spools."""

def _vals_equal(dict1, dict2):
if set(dict1) != set(dict2):
if (set1 := set(dict1)) != set(dict2):
return False
for key in set(dict1):
for key in set1:
val1, val2 = dict1[key], dict2[key]
if isinstance(val1, dict):
if not _vals_equal(val1, val2):
return False
return
# this is primarily for dataframes, which have an `equals` method.
elif hasattr(val1, "equals"):
if not val1.equals(val2):
return False
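The `equals` branch exists because elementwise comparison of DataFrames has no single truth value; a minimal illustration of the design choice:

```python
import pandas as pd

df1 = pd.DataFrame({"a": [1, 2]})
df2 = pd.DataFrame({"a": [1, 2]})

assert df1.equals(df2)  # single boolean
# bool(df1 == df2) would raise ValueError: (df1 == df2) is itself a
# DataFrame, whose truth value is ambiguous.
```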
@@ -89,12 +96,14 @@ def _vals_equal(dict1, dict2):
return _vals_equal(my_dict, other_dict)

@abc.abstractmethod
@compose_docstring(conflict_desc=attr_conflict_description)
def chunk(
self,
overlap: numeric_types | timeable_types | None = None,
keep_partial: bool = False,
snap_coords: bool = True,
tolerance: float = 1.5,
conflict: Literal["drop", "raise", "keep_first"] = "raise",
**kwargs,
) -> Self:
"""
@@ -114,6 +123,8 @@
tolerance
The number of samples a block of data can be spaced and still be
considered contiguous.
conflict
{conflict_desc}
kwargs
kwargs are used to specify the dimension along which to chunk, eg:
`time=10` chunks along the time axis in 10 second increments.
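A hedged sketch of the documented parameters on the bundled example data:

```python
import dascore as dc

spool = dc.get_example_spool()
# 10 s windows with 1 s of overlap; keep_partial retains the leftover
# segment at the end. time=None would instead merge contiguous patches.
chunked = spool.chunk(time=10, overlap=1, keep_partial=True)
```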
@@ -295,8 +306,10 @@ class DataFrameSpool(BaseSpool):
_instruction_df: pd.DataFrame = CacheDescriptor("_cache", "_get_instruction_df")
# kwargs for filtering contents
_select_kwargs: Mapping | None = FrozenDict()
# kwargs for merging patches
_merge_kwargs: Mapping | None = FrozenDict()
# attributes which affect merge groups for internal patches
_group_columns = ("network", "station", "dims", "data_type", "tag", "history")
_group_columns = ("network", "station", "dims", "data_type", "tag")
_drop_columns = ("patch",)

def _get_df(self):
@@ -308,9 +321,12 @@ def _get_source_df(self):
def _get_instruction_df(self):
"""Function to get the current df."""

def __init__(self, select_kwargs: dict | None = None):
def __init__(
self, select_kwargs: dict | None = None, merge_kwargs: dict | None = None
):
self._cache = {}
self._select_kwargs = {} if select_kwargs is None else select_kwargs
self._merge_kwargs = {} if merge_kwargs is None else merge_kwargs

def __getitem__(self, item):
out = self._get_patches_from_index(item)
@@ -375,7 +391,7 @@ def _patch_from_instruction_df(self, joined):
info["patch"] = trimmed_patch
out.append(info)
if len(out) > expected_len:
out = _force_patch_merge(out)
out = _force_patch_merge(out, merge_kwargs=self._merge_kwargs)
return [x["patch"] for x in out]

@staticmethod
@@ -417,6 +433,7 @@ def chunk(
keep_partial: bool = False,
snap_coords: bool = True,
tolerance: float = 1.5,
conflict: Literal["drop", "raise", "keep_first"] = "raise",
**kwargs,
) -> Self:
"""{doc}."""
@@ -426,16 +443,29 @@
keep_partial=keep_partial,
group_columns=self._group_columns,
tolerance=tolerance,
conflict=conflict,
**kwargs,
)
in_df, out_df = chunker.chunk(df)
if df.empty:
instructions = None
else:
instructions = chunker.get_instruction_df(in_df, out_df)
return self.new_from_df(out_df, source_df=self._df, instruction_df=instructions)
return self.new_from_df(
out_df,
source_df=self._df,
instruction_df=instructions,
merge_kwargs={"conflicts": conflict},
)

def new_from_df(self, df, source_df=None, instruction_df=None, select_kwargs=None):
def new_from_df(
self,
df,
source_df=None,
instruction_df=None,
select_kwargs=None,
merge_kwargs=None,
):
"""Create a new instance from dataframes."""
new = self.__class__(self)
df_, source_, inst_ = self._get_dummy_dataframes(df)
@@ -444,6 +474,8 @@ def new_from_df(self, df, source_df=None, instruction_df=None, select_kwargs=Non
new._instruction_df = instruction_df if instruction_df is not None else inst_
new._select_kwargs = dict(self._select_kwargs)
new._select_kwargs.update(select_kwargs or {})
new._merge_kwargs = dict(self._merge_kwargs)
new._merge_kwargs.update(merge_kwargs or {})
return new

@compose_docstring(doc=BaseSpool.select.__doc__)
@@ -487,11 +519,9 @@ def sort(self, attribute) -> Self:
old_indices = df.index
new_indices = np.arange(len(df))
mapper = pd.Series(new_indices, index=old_indices)

# swap out all the old values with new ones
new_current_index = inst_df["current_index"].map(mapper)
new_instruction_df = inst_df.assign(current_index=new_current_index)

# create new spool from new dataframes
return self.new_from_df(df=sorted_df, instruction_df=new_instruction_df)

21 changes: 18 additions & 3 deletions dascore/utils/chunk.py
@@ -8,8 +8,9 @@
import numpy as np
import pandas as pd

from dascore.constants import numeric_types, timeable_types
from dascore.constants import attr_conflict_description, numeric_types, timeable_types
from dascore.exceptions import CoordMergeError, ParameterError
from dascore.utils.docs import compose_docstring
from dascore.utils.misc import get_middle_value
from dascore.utils.pd import (
_remove_overlaps,
@@ -94,6 +95,7 @@ def get_intervals(
return np.stack([starts, ends]).T


@compose_docstring(attr_conflict=attr_conflict_description)
class ChunkManager:
"""
A class for managing the chunking of data defined in a dataframe.
@@ -115,6 +117,8 @@ class ChunkManager:
The upper limit of a gap to tolerate in terms of the sampling
along the desired dimension. E.G., the default value means entities
with gaps <= 1.5 * {name}_step will be merged.
conflict
{attr_conflict}
**kwargs
kwargs specify the column along which to chunk. The key specifies the
column along which to chunk, typically, `time` or `distance`, and the
@@ -132,21 +136,23 @@ def __init__(
group_columns: Collection[str] | None = None,
keep_partial=False,
tolerance=1.5,
conflict="raise",
**kwargs,
):
self._overlap = overlap
self._group_columns = group_columns
self._keep_partials = keep_partial
self._tolerance = tolerance
self._name, self._value = self._validate_kwargs(kwargs)
self._attr_conflict = conflict
self._validate_chunker()

def _validate_kwargs(self, kwargs):
"""Ensure kwargs is len one and has a valid."""
if not len(kwargs) == 1:
msg = (
f"Chunking only supported along one dimension. You passed "
f"You passed kwargs: {kwargs}"
f"kwargs: {kwargs}"
)
raise ParameterError(msg)
((key, value),) = kwargs.items()
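A small sketch of the validation above. `ChunkManager` is internal and normally constructed by `DataFrameSpool.chunk`; constructing it directly here is only for illustration:

```python
from dascore.utils.chunk import ChunkManager
from dascore.exceptions import ParameterError

ChunkManager(conflict="keep_first", time=10)  # exactly one dim kwarg: ok
try:
    ChunkManager(time=10, distance=5)  # two dim kwargs
except ParameterError as err:
    print(err)  # "Chunking only supported along one dimension. ..."
```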
@@ -218,12 +224,21 @@ def _create_df(self, df, name, start_stop, gnum):
out = pd.DataFrame(start_stop, columns=list(cols))
out[f"{name}_step"] = get_middle_value(df[f"{name}_step"].values)
merger = df.drop(columns=out.columns)
# get dims to determine which columns are still compared. Some test
# dfs don't have dims though, so it should still work without dims col.
dims = set(df.iloc[0].get("dims", "").split(","))
for col in set(merger.columns):
prefix = col.split("_")[0]
# If we have specified to ignore or remove conflicting attrs
# we don't need to check them here, but we do still check dims.
if self._attr_conflict != "raise" and prefix not in dims:
continue
vals = merger[col].unique()
if len(vals) > 1:
msg = (
f"Cannot merge on dim {self._name} because all values for "
f"{col} are not equal."
f"{col} are not equal. Consider using the `attr_conflict` "
f"argument to loosen this restriction."
)
raise CoordMergeError(msg)

25 changes: 23 additions & 2 deletions dascore/utils/misc.py
@@ -438,13 +438,17 @@ def separate_coord_info(
If provided, the required attributes (e.g., min, max, step).
cant_be_alone
names which cannot be on their own.
Returns
-------
coord_dict and attrs_dict.
"""

def _meets_required(coord_dict):
"""Return True coord dict does not meet the minimum required keys."""
"""Return True coord dict meets the minimum required keys."""
if not coord_dict:
return False
if required is None and (set(coord_dict) - cant_be_alone):
if not required and (set(coord_dict) - cant_be_alone):
return True
return set(coord_dict).issuperset(required)

@@ -590,3 +594,20 @@ def _spool_map(spool, func, size=None, client=None, progress=True, **kwargs):
spools = spool.split(size=size)
new_func = _MapFuncWrapper(func, kwargs, progress=progress)
return [x for y in client.map(new_func, spools) for x in y]


def _dict_list_diffs(dict_list):
"""Return the keys which are not equal dicts in a list."""
out = set()
first = dict_list[0]
first_keys = set(first)
for other in dict_list[1:]:
if other == first:
continue
other_keys = set(other)
out |= (other_keys - first_keys) | (first_keys - other_keys)
common_keys = other_keys & first_keys
for key in common_keys:
if first[key] != other[key]:
out.add(key)
return sorted(out)
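A quick illustration of the helper feeding the improved error message (fabricated dicts):

```python
from dascore.utils.misc import _dict_list_diffs

diffs = _dict_list_diffs([
    {"station": "A", "tag": "raw", "network": "X"},
    {"station": "A", "tag": "proc", "extra": 1},
])
print(diffs)  # ['extra', 'network', 'tag'] -- differing and missing keys
```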
5 changes: 3 additions & 2 deletions dascore/utils/patch.py
@@ -281,7 +281,7 @@ def merge_patches(
return dc.spool(patches).chunk(**{dim: None}, tolerance=tolerance)


def _force_patch_merge(patch_dict_list):
def _force_patch_merge(patch_dict_list, merge_kwargs, **kwargs):
"""
Force a merge of the patches along a dimension.
@@ -325,6 +325,7 @@ def _get_new_coord(df, merge_dim, coords):

df = pd.DataFrame(patch_dict_list)
merge_dim = _get_merge_col(df)
merge_kwargs = merge_kwargs if merge_kwargs is not None else {}
if merge_dim is None: # nothing to merge, complete overlap
return [patch_dict_list[0]]
dims = df["dims"].iloc[0].split(",")
@@ -338,7 +339,7 @@ def _get_new_coord(df, merge_dim, coords):
new_data = np.concatenate(datas, axis=axis)
new_coord = _get_new_coord(df, merge_dim, coords)
coord = new_coord.coord_map[merge_dim] if merge_dim in dims else None
new_attrs = combine_patch_attrs(attrs, merge_dim, coord=coord)
new_attrs = combine_patch_attrs(attrs, merge_dim, coord=coord, **merge_kwargs)
patch = dc.Patch(data=new_data, coords=new_coord, attrs=new_attrs, dims=dims)
new_dict = {"patch": patch}
return [new_dict]
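Since `merge_patches` now routes through `spool.chunk`, the forwarded `merge_kwargs` surface publicly via the `conflict` argument. A hedged sketch, assuming the example spool's patches are contiguous in time and that `instrument_id` is not among the group columns:

```python
import dascore as dc

spool = dc.get_example_spool()
patches = [p.update_attrs(instrument_id=str(i)) for i, p in enumerate(spool)]
# Default conflict="raise" would refuse this merge; "drop" simply
# discards the unequal instrument_id from the merged patch's attrs.
merged = dc.spool(patches).chunk(time=None, conflict="drop")
```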