From b9fbc95ff189248884b0e50e56d6f21fb8145b64 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 18:53:04 -0700 Subject: [PATCH 1/3] Fix DataTree.equals and DataTree.identical --- xarray/core/datatree.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bc818eb5671..3d3559acaeb 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1238,6 +1238,20 @@ def isomorphic( except (TypeError, TreeIsomorphismError): return False + def _matching( + self, + other: DataTree, + from_root: bool, + nodes_match: Callable[[DataTree, DataTree], bool], + ) -> bool: + if not self.isomorphic(other, from_root=from_root, strict_names=True): + return False + + return all( + nodes_match(node, other_node) + for node, other_node in zip(self.subtree, other.subtree) + ) + def equals(self, other: DataTree, from_root: bool = True) -> bool: """ Two DataTrees are equal if they have isomorphic node structures, with matching node names, @@ -1259,20 +1273,15 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: DataTree.isomorphic DataTree.identical """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): - return False - - return all( - [ - node.ds.equals(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) - ] - ) + to_ds = lambda x: x._to_dataset_view(inherited=True, rebuild_dims=False) + matcher = lambda x, y: to_ds(x).equals(to_ds(y)) + return self._matching(other, from_root, nodes_match=matcher) def identical(self, other: DataTree, from_root=True) -> bool: """ Like equals, but will also check all dataset attributes and the attributes on - all variables and coordinates. + all variables and coordinates, and requires the coordinates are defined at the + exact same levels of the DataTree hierarchy. By default this method will check the whole tree above the given node. @@ -1290,13 +1299,9 @@ def identical(self, other: DataTree, from_root=True) -> bool: DataTree.isomorphic DataTree.equals """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): - return False - - return all( - node.ds.identical(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) - ) + to_ds = lambda x: x._to_dataset_view(inherited=False, rebuild_dims=False) + matcher = lambda x, y: to_ds(x).identical(to_ds(y)) + return self._matching(other, from_root, nodes_match=matcher) def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: """ From 909593082014aaa810f8da6354e0948a93c47961 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 19:06:03 -0700 Subject: [PATCH 2/3] Add unit tests for identical and equals --- xarray/tests/test_datatree.py | 56 +++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f1f74d240f0..98360c4a339 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1091,6 +1091,62 @@ def f(x, tree, y): assert actual is dt and actual.attrs == attrs +class TestEqualsAndIdentical: + def test_basics(self, create_test_datatree): + dt = create_test_datatree() + assert dt.equals(dt) + assert dt.identical(dt) + + other = DataTree() + assert not dt.equals(other) + assert not dt.identical(other) + assert not other.equals(dt) + assert not other.identical(dt) + + def test_isomormorphic(self, create_test_datatree): + dt = create_test_datatree() + + dt_new_node = dt.copy() + dt_new_node["/something_new"] = DataTree() + assert not dt.equals(dt_new_node) + assert not dt.identical(dt_new_node) + assert not dt_new_node.equals(dt) + assert not dt_new_node.identical(dt) + + dt_new_array = dt.copy() + dt_new_array["/something_else"] = xr.DataArray(1234) + assert not dt.equals(dt_new_array) + assert not dt.identical(dt_new_array) + assert not dt_new_array.equals(dt) + assert not dt_new_array.identical(dt) + + def test_equal_but_not_identical_inheritance(self): + duplicated_coords = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1]}), + "/sub": Dataset(coords={"x": [1]}), + } + ) + inherited_coords = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1]}), + "/sub": Dataset(), + } + ) + assert duplicated_coords.equals(inherited_coords) + assert inherited_coords.equals(duplicated_coords) + assert not duplicated_coords.identical(inherited_coords) + assert not inherited_coords.identical(duplicated_coords) + + def test_equal_but_not_identical_attrs(self): + without_attrs = DataTree() + with_attrs = DataTree(Dataset(attrs={"foo": "bar"})) + assert without_attrs.equals(with_attrs) + assert with_attrs.equals(without_attrs) + assert not without_attrs.identical(with_attrs) + assert not with_attrs.identical(without_attrs) + + class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? From 5eee5fde7ef14240044247d56a46e44216e758e0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Sep 2024 19:22:10 -0700 Subject: [PATCH 3/3] Remove from_root --- xarray/core/datatree.py | 31 ++++++++++++--------------- xarray/testing/assertions.py | 28 +++++------------------- xarray/tests/test_datatree_mapping.py | 4 ++-- 3 files changed, 21 insertions(+), 42 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3d3559acaeb..7ff3916e317 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1241,31 +1241,25 @@ def isomorphic( def _matching( self, other: DataTree, - from_root: bool, nodes_match: Callable[[DataTree, DataTree], bool], ) -> bool: - if not self.isomorphic(other, from_root=from_root, strict_names=True): + if not self.isomorphic(other, from_root=True, strict_names=True): return False return all( nodes_match(node, other_node) - for node, other_node in zip(self.subtree, other.subtree) + for node, other_node in zip(self.descendants, other.descendants) ) - def equals(self, other: DataTree, from_root: bool = True) -> bool: + def equals(self, other: DataTree) -> bool: """ Two DataTrees are equal if they have isomorphic node structures, with matching node names, and if they have matching variables and coordinates, all of which are equal. - By default this method will check the whole tree above the given node. - Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1275,23 +1269,20 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: """ to_ds = lambda x: x._to_dataset_view(inherited=True, rebuild_dims=False) matcher = lambda x, y: to_ds(x).equals(to_ds(y)) - return self._matching(other, from_root, nodes_match=matcher) + if not matcher(self, other): + return False + return self._matching(other, nodes_match=matcher) - def identical(self, other: DataTree, from_root=True) -> bool: + def identical(self, other: DataTree) -> bool: """ Like equals, but will also check all dataset attributes and the attributes on all variables and coordinates, and requires the coordinates are defined at the exact same levels of the DataTree hierarchy. - By default this method will check the whole tree above the given node. - Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1299,9 +1290,15 @@ def identical(self, other: DataTree, from_root=True) -> bool: DataTree.isomorphic DataTree.equals """ + # include inherited coordinates only at the root; otherwise require that + # everything is also defined at the same level + self_view = self._to_dataset_view(inherited=True, rebuild_dims=False) + other_view = other._to_dataset_view(inherited=True, rebuild_dims=False) + if not self_view.identical(other_view): + return False to_ds = lambda x: x._to_dataset_view(inherited=False, rebuild_dims=False) matcher = lambda x, y: to_ds(x).identical(to_ds(y)) - return self._matching(other, from_root, nodes_match=matcher) + return self._matching(other, nodes_match=matcher) def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: """ diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 3a0cc96b20d..dd38dd59b98 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -116,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... @ensure_warnings -def assert_equal(a, b, from_root=True, check_dim_order: bool = True): +def assert_equal(a, b, check_dim_order: bool = True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -135,10 +135,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): or xarray.core.datatree.DataTree. The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates or xarray.core.datatree.DataTree. The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. check_dim_order : bool, optional, default is True Whether dimensions must be in the same order. @@ -159,11 +155,7 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): elif isinstance(a, Coordinates): assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") + assert a.equals(b), diff_datatree_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -173,11 +165,11 @@ def assert_identical(a, b): ... @overload -def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... +def assert_identical(a: DataTree, b: DataTree): ... @ensure_warnings -def assert_identical(a, b, from_root=True): +def assert_identical(a, b): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. @@ -193,10 +185,6 @@ def assert_identical(a, b, from_root=True): The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. check_dim_order : bool, optional, default is True Whether dimensions must be in the same order. @@ -220,13 +208,7 @@ def assert_identical(a, b, from_root=True): elif isinstance(a, Coordinates): assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.identical(b, from_root=from_root), diff_datatree_repr( - a, b, "identical" - ) + assert a.identical(b), diff_datatree_repr(a, b, "identical") else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 1d2595cd013..26484b2b131 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -251,9 +251,9 @@ def test_discard_ancestry(self, create_test_datatree): def times_ten(ds): return 10.0 * ds - expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"].copy() result_tree = times_ten(subtree) - assert_equal(result_tree, expected, from_root=False) + assert_equal(result_tree, expected) def test_skip_empty_nodes_with_attrs(self, create_test_datatree): # inspired by xarray-datatree GH262