diff --git a/dascore/constants.py b/dascore/constants.py
index c6a07c1d..a9e5755f 100644
--- a/dascore/constants.py
+++ b/dascore/constants.py
@@ -174,6 +174,16 @@ def map(self, func, iterables, **kwargs):
 same units as the specified dimension, or have units attached.
 """
 
+attr_conflict_description = """
+Indicates how to handle conflicts in attributes other than those
+indicated by dim (e.g., tag, history, station). If "drop", simply
+drop conflicting attributes, or attributes not shared by all models.
+If "raise", raise an
+[AttributeMergeError](`dascore.exceptions.AttributeMergeError`) when
+issues are encountered. If "keep_first", just keep the first value
+for each attribute.
+"""
+
 # Rich styles for various object displays.
 dascore_styles = dict(
diff --git a/dascore/core/attrs.py b/dascore/core/attrs.py
index 28ebf2f4..2f937a3d 100644
--- a/dascore/core/attrs.py
+++ b/dascore/core/attrs.py
@@ -18,6 +18,7 @@
     VALID_DATA_CATEGORIES,
     VALID_DATA_TYPES,
     PatchType,
+    attr_conflict_description,
     basic_summary_attrs,
     max_lens,
 )
@@ -27,6 +28,7 @@
 from dascore.utils.docs import compose_docstring
 from dascore.utils.mapping import FrozenDict
 from dascore.utils.misc import (
+    _dict_list_diffs,
     all_diffs_close_enough,
     get_middle_value,
     iterate,
@@ -263,6 +265,7 @@ def flat_dump(self, dim_tuple=False, exclude=None) -> dict:
         return out
 
 
+@compose_docstring(conflict_desc=attr_conflict_description)
 def combine_patch_attrs(
     model_list: Sequence[PatchAttrs],
     coord_name: str | None = None,
@@ -280,12 +283,7 @@
     coord_name
         The coordinate, usually a dimension coord, along which to merge.
     conflicts
-        Indicates how to handle conflicts in attributes other than those
-        indicated by dim. If "drop" simply drop conflicting attributes,
-        or attributes not shared by all models. If "raise" raise an
-        [AttributeMergeError](`dascore.exceptions.AttributeMergeError`] when
-        issues are encountered. If "keep_first", just keep the first value
-        for each attribute.
+        {conflict_desc}
     drop_attrs
         If provided, attributes which should be dropped.
     coord
@@ -393,12 +391,18 @@ def _handle_other_attrs(mod_dict_list):
         if conflicts == "keep_first":
             return [dict(ChainMap(*mod_dict_list))]
         no_null_ = _replace_null_with_None(mod_dict_list)
-        all_eq = all(no_null_[0] == x for x in no_null_)
+        all_eq = all(no_null_[0] == x for x in no_null_[1:])
         if all_eq:
             return mod_dict_list
-        # now the fun part.
         if conflicts == "raise":
-            msg = "Cannot merge models, not all of their non-dim attrs are equal."
+            # determine which keys are not equal to help debug.
+            uneq_keys = _dict_list_diffs(mod_dict_list)
+            msg = (
+                "Cannot merge models, the following non-dim attrs are not "
+                f"equal: {uneq_keys}. Consider setting the `conflicts` "
+                "argument (`conflict` in `Spool.chunk`) for more flexibility "
+                "in merging unequal attributes."
+            )
             raise AttributeMergeError(msg)
         final_dict = reduce(_keep_eq, mod_dict_list)
         return [final_dict]
diff --git a/dascore/core/spool.py b/dascore/core/spool.py
index ec0904db..e74f6f68 100644
--- a/dascore/core/spool.py
+++ b/dascore/core/spool.py
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Generator, Mapping, Sequence
 from functools import singledispatch
 from pathlib import Path
-from typing import TypeVar
+from typing import Literal, TypeVar
 
 import numpy as np
 import pandas as pd
@@ -14,7 +14,13 @@
 import dascore as dc
 import dascore.io
-from dascore.constants import ExecutorType, PatchType, numeric_types, timeable_types
+from dascore.constants import (
+    ExecutorType,
+    PatchType,
+    attr_conflict_description,
+    numeric_types,
+    timeable_types,
+)
 from dascore.exceptions import InvalidSpoolError, ParameterError
 from dascore.utils.chunk import ChunkManager
 from dascore.utils.display import get_dascore_text, get_nice_text
@@ -70,13 +76,14 @@ def __eq__(self, other) -> bool:
         """Simple equality checks on spools."""
 
         def _vals_equal(dict1, dict2):
-            if set(dict1) != set(dict2):
+            if (set1 := set(dict1)) != set(dict2):
                 return False
-            for key in set(dict1):
+            for key in set1:
                 val1, val2 = dict1[key], dict2[key]
                 if isinstance(val1, dict):
                     if not _vals_equal(val1, val2):
                         return False
+                # this is primarily for dataframes, which have an `equals` method.
                 elif hasattr(val1, "equals"):
                     if not val1.equals(val2):
                         return False
@@ -89,12 +96,14 @@
         return _vals_equal(my_dict, other_dict)
 
     @abc.abstractmethod
+    @compose_docstring(conflict_desc=attr_conflict_description)
    def chunk(
         self,
         overlap: numeric_types | timeable_types | None = None,
         keep_partial: bool = False,
         snap_coords: bool = True,
         tolerance: float = 1.5,
+        conflict: Literal["drop", "raise", "keep_first"] = "raise",
         **kwargs,
     ) -> Self:
         """
@@ -114,6 +123,8 @@
         tolerance
             The number of samples a block of data can be spaced and still
             be considered contiguous.
+        conflict
+            {conflict_desc}
         kwargs
             kwargs are used to specify the dimension along which to chunk, eg:
             `time=10` chunks along the time axis in 10 second increments.
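
[Reviewer note] A minimal sketch of how the new `conflict` argument to
`Spool.chunk` is meant to be used, mirroring the `test_attrs_conflict` test
added below. It assumes `dc.get_example_spool()` returns a spool of adjacent,
otherwise-identical patches; `my_attr` is a hypothetical attribute name:

    import dascore as dc

    # Give each patch a conflicting value for one non-dim attribute.
    patches = [
        patch.update_attrs(my_attr=num)
        for num, patch in enumerate(dc.get_example_spool())
    ]
    spool = dc.spool(patches)

    # The default, conflict="raise", refuses to merge and raises a
    # CoordMergeError naming the offending attribute.
    # spool.chunk(time=...)

    # conflict="keep_first" keeps the first value of each conflicting
    # attribute and merges the rest of the spool normally.
    merged = spool.chunk(time=..., conflict="keep_first")
    assert len(merged) == 1
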
@@ -295,8 +306,10 @@ class DataFrameSpool(BaseSpool):
     _instruction_df: pd.DataFrame = CacheDescriptor("_cache", "_get_instruction_df")
     # kwargs for filtering contents
     _select_kwargs: Mapping | None = FrozenDict()
+    # kwargs for merging patches
+    _merge_kwargs: Mapping | None = FrozenDict()
     # attributes which effect merge groups for internal patches
-    _group_columns = ("network", "station", "dims", "data_type", "tag", "history")
+    _group_columns = ("network", "station", "dims", "data_type", "tag")
     _drop_columns = ("patch",)
 
     def _get_df(self):
@@ -308,9 +321,12 @@ def _get_source_df(self):
 
     def _get_instruction_df(self):
         """Function to get the current df."""
 
-    def __init__(self, select_kwargs: dict | None = None):
+    def __init__(
+        self, select_kwargs: dict | None = None, merge_kwargs: dict | None = None
+    ):
         self._cache = {}
         self._select_kwargs = {} if select_kwargs is None else select_kwargs
+        self._merge_kwargs = {} if merge_kwargs is None else merge_kwargs
 
     def __getitem__(self, item):
         out = self._get_patches_from_index(item)
@@ -375,7 +391,7 @@ def _patch_from_instruction_df(self, joined):
             info["patch"] = trimmed_patch
             out.append(info)
         if len(out) > expected_len:
-            out = _force_patch_merge(out)
+            out = _force_patch_merge(out, merge_kwargs=self._merge_kwargs)
         return [x["patch"] for x in out]
 
     @staticmethod
@@ -417,6 +433,7 @@ def chunk(
         keep_partial: bool = False,
         snap_coords: bool = True,
         tolerance: float = 1.5,
+        conflict: Literal["drop", "raise", "keep_first"] = "raise",
         **kwargs,
     ) -> Self:
         """{doc}."""
@@ -426,6 +443,7 @@
             keep_partial=keep_partial,
             group_columns=self._group_columns,
             tolerance=tolerance,
+            conflict=conflict,
             **kwargs,
         )
         in_df, out_df = chunker.chunk(df)
@@ -433,9 +451,21 @@
             instructions = None
         else:
             instructions = chunker.get_instruction_df(in_df, out_df)
-        return self.new_from_df(out_df, source_df=self._df, instruction_df=instructions)
+        return self.new_from_df(
+            out_df,
+            source_df=self._df,
+            instruction_df=instructions,
+            merge_kwargs={"conflicts": conflict},
+        )
 
-    def new_from_df(self, df, source_df=None, instruction_df=None, select_kwargs=None):
+    def new_from_df(
+        self,
+        df,
+        source_df=None,
+        instruction_df=None,
+        select_kwargs=None,
+        merge_kwargs=None,
+    ):
         """Create a new instance from dataframes."""
         new = self.__class__(self)
         df_, source_, inst_ = self._get_dummy_dataframes(df)
@@ -444,6 +474,8 @@ def new_from_df(self, df, source_df=None, instruction_df=None, select_kwargs=Non
         new._instruction_df = instruction_df if instruction_df is not None else inst_
         new._select_kwargs = dict(self._select_kwargs)
         new._select_kwargs.update(select_kwargs or {})
+        new._merge_kwargs = dict(self._merge_kwargs)
+        new._merge_kwargs.update(merge_kwargs or {})
         return new
 
     @compose_docstring(doc=BaseSpool.select.__doc__)
@@ -487,11 +519,9 @@ def sort(self, attribute) -> Self:
         old_indices = df.index
         new_indices = np.arange(len(df))
         mapper = pd.Series(new_indices, index=old_indices)
-        # swap out all the old values with new ones
         new_current_index = inst_df["current_index"].map(mapper)
         new_instruction_df = inst_df.assign(current_index=new_current_index)
-        # create new spool from new dataframes
         return self.new_from_df(df=sorted_df, instruction_df=new_instruction_df)
diff --git a/dascore/utils/chunk.py b/dascore/utils/chunk.py
index 80be0337..51dca0e0 100644
--- a/dascore/utils/chunk.py
+++ b/dascore/utils/chunk.py
@@ -8,8 +8,9 @@
 import numpy as np
 import pandas as pd
 
-from dascore.constants import numeric_types, timeable_types
+from dascore.constants import attr_conflict_description, numeric_types, timeable_types
 from dascore.exceptions import CoordMergeError, ParameterError
+from dascore.utils.docs import compose_docstring
 from dascore.utils.misc import get_middle_value
 from dascore.utils.pd import (
     _remove_overlaps,
@@ -94,6 +95,7 @@ def get_intervals(
     return np.stack([starts, ends]).T
 
 
+@compose_docstring(attr_conflict=attr_conflict_description)
 class ChunkManager:
     """
     A class for managing the chunking of data defined in a dataframe.
@@ -115,6 +117,8 @@
         The upper limit of a gap to tolerate in terms of the sampling
         along the desired dimension. E.G., the default value means entities
         with gaps <= 1.5 * {name}_step will be merged.
+    conflict
+        {attr_conflict}
    **kawrgs
         kwargs specify the column along which to chunk. The key specifies the
         column along which to chunk, typically, `time` or `distance`, and the
@@ -132,6 +136,7 @@ def __init__(
         group_columns: Collection[str] | None = None,
         keep_partial=False,
         tolerance=1.5,
+        conflict="raise",
         **kwargs,
     ):
         self._overlap = overlap
@@ -139,6 +144,7 @@
         self._keep_partials = keep_partial
         self._tolerance = tolerance
         self._name, self._value = self._validate_kwargs(kwargs)
+        self._attr_conflict = conflict
         self._validate_chunker()
 
     def _validate_kwargs(self, kwargs):
@@ -146,7 +152,7 @@
         if not len(kwargs) == 1:
             msg = (
                 f"Chunking only supported along one dimension. You passed "
-                f"You passed kwargs: {kwargs}"
+                f"kwargs: {kwargs}"
             )
             raise ParameterError(msg)
         ((key, value),) = kwargs.items()
@@ -218,12 +224,21 @@ def _create_df(self, df, name, start_stop, gnum):
         out = pd.DataFrame(start_stop, columns=list(cols))
         out[f"{name}_step"] = get_middle_value(df[f"{name}_step"].values)
         merger = df.drop(columns=out.columns)
+        # get dims to determine which columns are still compared. Some test
+        # dfs don't have a dims column, so this should still work without it.
+        dims = set(df.iloc[0].get("dims", "").split(","))
         for col in set(merger.columns):
+            prefix = col.split("_")[0]
+            # If conflicting attrs are to be ignored or dropped, we don't
+            # need to check them here, but we do still check dims.
+            if self._attr_conflict != "raise" and prefix not in dims:
+                continue
             vals = merger[col].unique()
             if len(vals) > 1:
                 msg = (
                     f"Cannot merge on dim {self._name} because all values for "
-                    f"{col} are not equal."
+                    f"{col} are not equal. Consider using the `conflict` "
+                    f"argument to loosen this restriction."
                 )
                 raise CoordMergeError(msg)
diff --git a/dascore/utils/misc.py b/dascore/utils/misc.py
index fb5d477e..b3f049fc 100644
--- a/dascore/utils/misc.py
+++ b/dascore/utils/misc.py
@@ -438,13 +438,17 @@ def separate_coord_info(
         If provided, the required attributes (e.g., min, max, step).
     cant_be_alone
         names which cannot be on their own.
+
+    Returns
+    -------
+    coord_dict and attrs_dict.
""" def _meets_required(coord_dict): - """Return True coord dict does not meet the minimum required keys.""" + """Return True coord dict meets the minimum required keys.""" if not coord_dict: return False - if required is None and (set(coord_dict) - cant_be_alone): + if not required and (set(coord_dict) - cant_be_alone): return True return set(coord_dict).issuperset(required) @@ -590,3 +594,20 @@ def _spool_map(spool, func, size=None, client=None, progress=True, **kwargs): spools = spool.split(size=size) new_func = _MapFuncWrapper(func, kwargs, progress=progress) return [x for y in client.map(new_func, spools) for x in y] + + +def _dict_list_diffs(dict_list): + """Return the keys which are not equal dicts in a list.""" + out = set() + first = dict_list[0] + first_keys = set(first) + for other in dict_list[1:]: + if other == first: + continue + other_keys = set(other) + out |= (other_keys - first_keys) | (first_keys - other_keys) + common_keys = other_keys & first_keys + for key in common_keys: + if first[key] != other[key]: + out.add(key) + return sorted(out) diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py index 414fc722..7bc0bc77 100644 --- a/dascore/utils/patch.py +++ b/dascore/utils/patch.py @@ -281,7 +281,7 @@ def merge_patches( return dc.spool(patches).chunk(**{dim: None}, tolerance=tolerance) -def _force_patch_merge(patch_dict_list): +def _force_patch_merge(patch_dict_list, merge_kwargs, **kwargs): """ Force a merge of the patches along a dimension. @@ -325,6 +325,7 @@ def _get_new_coord(df, merge_dim, coords): df = pd.DataFrame(patch_dict_list) merge_dim = _get_merge_col(df) + merge_kwargs = merge_kwargs if merge_kwargs is not None else {} if merge_dim is None: # nothing to merge, complete overlap return [patch_dict_list[0]] dims = df["dims"].iloc[0].split(",") @@ -338,7 +339,7 @@ def _get_new_coord(df, merge_dim, coords): new_data = np.concatenate(datas, axis=axis) new_coord = _get_new_coord(df, merge_dim, coords) coord = new_coord.coord_map[merge_dim] if merge_dim in dims else None - new_attrs = combine_patch_attrs(attrs, merge_dim, coord=coord) + new_attrs = combine_patch_attrs(attrs, merge_dim, coord=coord, **merge_kwargs) patch = dc.Patch(data=new_data, coords=new_coord, attrs=new_attrs, dims=dims) new_dict = {"patch": patch} return [new_dict] diff --git a/tests/test_core/test_attrs.py b/tests/test_core/test_attrs.py index 8fbed801..b2524e2d 100644 --- a/tests/test_core/test_attrs.py +++ b/tests/test_core/test_attrs.py @@ -317,18 +317,20 @@ def test_drop(self): """Ensure drop_attrs does its job.""" pa1 = PatchAttrs(history=["a", "b"]) pa2 = PatchAttrs() - with pytest.raises(AttributeMergeError, match="not all of their non-dim"): + msg = "the following non-dim attrs are not equal" + with pytest.raises(AttributeMergeError, match=msg): combine_patch_attrs([pa1, pa2]) out = combine_patch_attrs([pa1, pa2], drop_attrs="history") assert isinstance(out, PatchAttrs) def test_conflicts(self): """Ensure when non-dim fields aren't equal merge raises.""" - pa1 = PatchAttrs(tag="bob") - pa2 = PatchAttrs() - match = "not all of their non-dim" - with pytest.raises(AttributeMergeError, match=match): - combine_patch_attrs([pa1, pa2]) + pa1 = PatchAttrs(tag="bob", another=2, same=42) + pa2 = PatchAttrs(tag="bob", another=2, same=42) + pa3 = PatchAttrs(another=1, same=42, different=10) + msg = "the following non-dim attrs are not equal" + with pytest.raises(AttributeMergeError, match=msg): + combine_patch_attrs([pa1, pa2, pa3]) def test_missing_coordinate(self): """When one model has 
a missing coord it should just get dropped.""" diff --git a/tests/test_core/test_patch_chunk.py b/tests/test_core/test_patch_chunk.py index 702fbe89..187de704 100644 --- a/tests/test_core/test_patch_chunk.py +++ b/tests/test_core/test_patch_chunk.py @@ -204,6 +204,15 @@ def distance_adjacent_no_order(self, wacky_dim_patch): pa2 = pa1.update_attrs(distance_min=new_dist) return dc.spool([pa1, pa2]) + @pytest.fixture(scope="class") + def adjacent_spool_different_attrs(self, adjacent_spool_no_overlap): + """An adjacent spool with on attribute that is different on each patch.""" + out = [] + for num, patch in enumerate(adjacent_spool_no_overlap): + out.append(patch.update_attrs(my_attr=num)) + # since + return dc.spool(out) + def test_merge_unequal_other(self, distance_adjacent): """When distance values are not equal time shouldn't be merge-able.""" with pytest.raises(CoordMergeError): @@ -373,3 +382,18 @@ def test_merge_select(self, adjacent_spool_no_overlap): assert (time_min - time_step) < time_tup[0] assert time_max <= time_tup[1] assert (time_max + time_step) > time_tup[1] + + def test_attrs_conflict(self, adjacent_spool_different_attrs): + """Test various cases for specifying what to do when attrs conflict.""" + spool = adjacent_spool_different_attrs + # when we don't specify to ignore or drop attrs this should raise. + match = "all values for my_attr" + with pytest.raises(CoordMergeError, match=match): + spool.chunk(time=...) + # however, when we specify drop attrs this shouldn't. + out = spool.chunk(time=..., conflict="keep_first") + assert isinstance(out, dc.BaseSpool) + assert len(out) == 1 + # make sure we can read the patch + patch = out[0] + assert isinstance(patch, dc.Patch) diff --git a/tests/test_core/test_spool.py b/tests/test_core/test_spool.py index ddd8260a..716a319f 100644 --- a/tests/test_core/test_spool.py +++ b/tests/test_core/test_spool.py @@ -67,6 +67,8 @@ def test_chunked_differently(self, random_spool): """Spools with different chunking should !=.""" sp1 = random_spool.chunk(time=1.12) assert sp1 != random_spool + sp2 = random_spool.chunk(time=1.00) + assert sp2 != sp1 def test_eq_self(self, random_spool): """A spool should always eq itself.""" @@ -80,6 +82,14 @@ def test_unequal_attr(self, random_spool): new2.__dict__["bad_attr"] = 2 assert new1 != new2 + def test_unequal_dicts(self, random_spool): + """Simulate some dicts which don't have the same values.""" + new1 = copy.deepcopy(random_spool) + new1.__dict__["bad_attr"] = {1: 2} + new2 = copy.deepcopy(random_spool) + new2.__dict__["bad_attr"] = {2: 3} + assert new1 != new2 + class TestIndexing: """Tests for indexing spools to retrieve patches.""" diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 8f766a96..5642b0de 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -17,6 +17,7 @@ iter_files, iterate, optional_import, + separate_coord_info, ) @@ -304,3 +305,18 @@ def test_kwargs(self): john = self._JohnnyCached() assert john.multiargs(1, b=1) == 2 assert john.multiargs(a=2, b=3) == 5 + + +class TestSeparateCoordInfo: + """Tests for separating coord info from attr dict.""" + + def test_empty(self): + """Empty args should return emtpy dicts.""" + out1, out2 = separate_coord_info(None) + assert out1 == out2 == {} + + def test_meets_reqs(self): + """Simple case for filtering out required attrs.""" + input_dict = {"coords": {"time": {"min": 10}}} + coords, attrs = separate_coord_info(input_dict) + assert coords == input_dict["coords"]
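
[Reviewer note] A quick sanity check of the new `_dict_list_diffs` helper in
dascore/utils/misc.py, using made-up dicts; keys whose values differ and keys
missing from one of the dicts are both reported, in sorted order:

    from dascore.utils.misc import _dict_list_diffs

    # "b" differs in value and "c" is absent from the first dict.
    assert _dict_list_diffs([{"a": 1, "b": 2}, {"a": 1, "b": 3, "c": 4}]) == ["b", "c"]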