diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bc818eb5671..7ff3916e317 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1238,20 +1238,28 @@ def isomorphic( except (TypeError, TreeIsomorphismError): return False - def equals(self, other: DataTree, from_root: bool = True) -> bool: + def _matching( + self, + other: DataTree, + nodes_match: Callable[[DataTree, DataTree], bool], + ) -> bool: + if not self.isomorphic(other, from_root=True, strict_names=True): + return False + + return all( + nodes_match(node, other_node) + for node, other_node in zip(self.descendants, other.descendants) + ) + + def equals(self, other: DataTree) -> bool: """ Two DataTrees are equal if they have isomorphic node structures, with matching node names, and if they have matching variables and coordinates, all of which are equal. - By default this method will check the whole tree above the given node. - Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1259,30 +1267,22 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: DataTree.isomorphic DataTree.identical """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): + to_ds = lambda x: x._to_dataset_view(inherited=True, rebuild_dims=False) + matcher = lambda x, y: to_ds(x).equals(to_ds(y)) + if not matcher(self, other): return False + return self._matching(other, nodes_match=matcher) - return all( - [ - node.ds.equals(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) - ] - ) - - def identical(self, other: DataTree, from_root=True) -> bool: + def identical(self, other: DataTree) -> bool: """ Like equals, but will also check all dataset attributes and the attributes on - all variables and coordinates. - - By default this method will check the whole tree above the given node. + all variables and coordinates, and requires the coordinates are defined at the + exact same levels of the DataTree hierarchy. Parameters ---------- other : DataTree The other tree object to compare to. - from_root : bool, optional, default is True - Whether or not to first traverse to the root of the two trees before checking for isomorphism. - If neither tree has a parent then this has no effect. See Also -------- @@ -1290,13 +1290,15 @@ def identical(self, other: DataTree, from_root=True) -> bool: DataTree.isomorphic DataTree.equals """ - if not self.isomorphic(other, from_root=from_root, strict_names=True): + # include inherited coordinates only at the root; otherwise require that + # everything is also defined at the same level + self_view = self._to_dataset_view(inherited=True, rebuild_dims=False) + other_view = other._to_dataset_view(inherited=True, rebuild_dims=False) + if not self_view.identical(other_view): return False - - return all( - node.ds.identical(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) - ) + to_ds = lambda x: x._to_dataset_view(inherited=False, rebuild_dims=False) + matcher = lambda x, y: to_ds(x).identical(to_ds(y)) + return self._matching(other, nodes_match=matcher) def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: """ diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 3a0cc96b20d..dd38dd59b98 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -116,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... @ensure_warnings -def assert_equal(a, b, from_root=True, check_dim_order: bool = True): +def assert_equal(a, b, check_dim_order: bool = True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -135,10 +135,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): or xarray.core.datatree.DataTree. The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates or xarray.core.datatree.DataTree. The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. check_dim_order : bool, optional, default is True Whether dimensions must be in the same order. @@ -159,11 +155,7 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True): elif isinstance(a, Coordinates): assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals") + assert a.equals(b), diff_datatree_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -173,11 +165,11 @@ def assert_identical(a, b): ... @overload -def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ... +def assert_identical(a: DataTree, b: DataTree): ... @ensure_warnings -def assert_identical(a, b, from_root=True): +def assert_identical(a, b): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. @@ -193,10 +185,6 @@ def assert_identical(a, b, from_root=True): The first object to compare. b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. - from_root : bool, optional, default is True - Only used when comparing DataTree objects. Indicates whether or not to - first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. check_dim_order : bool, optional, default is True Whether dimensions must be in the same order. @@ -220,13 +208,7 @@ def assert_identical(a, b, from_root=True): elif isinstance(a, Coordinates): assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") elif isinstance(a, DataTree): - if from_root: - a = a.root - b = b.root - - assert a.identical(b, from_root=from_root), diff_datatree_repr( - a, b, "identical" - ) + assert a.identical(b), diff_datatree_repr(a, b, "identical") else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f1f74d240f0..98360c4a339 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1091,6 +1091,62 @@ def f(x, tree, y): assert actual is dt and actual.attrs == attrs +class TestEqualsAndIdentical: + def test_basics(self, create_test_datatree): + dt = create_test_datatree() + assert dt.equals(dt) + assert dt.identical(dt) + + other = DataTree() + assert not dt.equals(other) + assert not dt.identical(other) + assert not other.equals(dt) + assert not other.identical(dt) + + def test_isomormorphic(self, create_test_datatree): + dt = create_test_datatree() + + dt_new_node = dt.copy() + dt_new_node["/something_new"] = DataTree() + assert not dt.equals(dt_new_node) + assert not dt.identical(dt_new_node) + assert not dt_new_node.equals(dt) + assert not dt_new_node.identical(dt) + + dt_new_array = dt.copy() + dt_new_array["/something_else"] = xr.DataArray(1234) + assert not dt.equals(dt_new_array) + assert not dt.identical(dt_new_array) + assert not dt_new_array.equals(dt) + assert not dt_new_array.identical(dt) + + def test_equal_but_not_identical_inheritance(self): + duplicated_coords = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1]}), + "/sub": Dataset(coords={"x": [1]}), + } + ) + inherited_coords = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1]}), + "/sub": Dataset(), + } + ) + assert duplicated_coords.equals(inherited_coords) + assert inherited_coords.equals(duplicated_coords) + assert not duplicated_coords.identical(inherited_coords) + assert not inherited_coords.identical(duplicated_coords) + + def test_equal_but_not_identical_attrs(self): + without_attrs = DataTree() + with_attrs = DataTree(Dataset(attrs={"foo": "bar"})) + assert without_attrs.equals(with_attrs) + assert with_attrs.equals(without_attrs) + assert not without_attrs.identical(with_attrs) + assert not with_attrs.identical(without_attrs) + + class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 1d2595cd013..26484b2b131 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -251,9 +251,9 @@ def test_discard_ancestry(self, create_test_datatree): def times_ten(ds): return 10.0 * ds - expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"].copy() result_tree = times_ten(subtree) - assert_equal(result_tree, expected, from_root=False) + assert_equal(result_tree, expected) def test_skip_empty_nodes_with_attrs(self, create_test_datatree): # inspired by xarray-datatree GH262