From c4eeea6d2204ade5b4a3378fae4de4ec71f8121e Mon Sep 17 00:00:00 2001 From: MatthisCl <76970409+MatthisCl@users.noreply.github.com> Date: Tue, 9 Jul 2024 16:25:45 +0200 Subject: [PATCH] Add new add_nx_graphs method to skeleton.py (#1130) * add new method to skeleton.py add_nx_graphs * refac add_nx_graph and add unit test * add snapshot nml and test to check old nml generation against new annotation.save after add_nx_graphs * fix typechecks * Merge branch 'master' into 475-add-from_nx_graphs-method-to-skeleton * update Changelog.md --- webknossos/Changelog.md | 1 + .../testdata/nmls/generate_nml_snapshot.nml | 39 ++++++++ webknossos/tests/test_skeleton.py | 94 +++++++++++++++++++ webknossos/webknossos/skeleton/skeleton.py | 42 ++++++++- 4 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 webknossos/testdata/nmls/generate_nml_snapshot.nml diff --git a/webknossos/Changelog.md b/webknossos/Changelog.md index ab82a80fb..1b9ddca10 100644 --- a/webknossos/Changelog.md +++ b/webknossos/Changelog.md @@ -16,6 +16,7 @@ For upgrade instructions, please check the respective _Breaking Changes_ section ### Added - Added an implementation of padded_with_margins for NDBoundingBox class. [#1120](https://github.com/scalableminds/webknossos-libs/pull/1120) +- Added a new method add_nx_graphs to skeleton.py which supports to add nx.Graphs to the Skeleton object. [#1130](https://github.com/scalableminds/webknossos-libs/pull/1130) ### Changed - Removed additional logging messages during image conversion. [#1124](https://github.com/scalableminds/webknossos-libs/pull/1124) diff --git a/webknossos/testdata/nmls/generate_nml_snapshot.nml b/webknossos/testdata/nmls/generate_nml_snapshot.nml new file mode 100644 index 000000000..21503abf3 --- /dev/null +++ b/webknossos/testdata/nmls/generate_nml_snapshot.nml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/webknossos/tests/test_skeleton.py b/webknossos/tests/test_skeleton.py index 0fe3ce437..2e5d13d14 100644 --- a/webknossos/tests/test_skeleton.py +++ b/webknossos/tests/test_skeleton.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import List, Optional +import networkx as nx import pytest import webknossos as wk @@ -51,6 +52,15 @@ def create_dummy_skeleton() -> wk.Skeleton: return nml +def create_dummy_nx_graph() -> nx.Graph: + nx_graph = nx.Graph() + nx_graph.add_node(1, position=(0, 1, 2), comment="node 1 nx") + nx_graph.add_node(2, position=(3, 1, 2), comment="node 2 nx") + nx_graph.add_edge(1, 2) + + return nx_graph + + def test_doc_example() -> None: from webknossos import Annotation @@ -108,6 +118,90 @@ def test_skeleton_creation() -> None: assert grand_children[0].group == groups[0] +def test_add_nx_graph() -> None: + skeleton = create_dummy_skeleton() + node_count = skeleton.get_total_node_count() + tree_count = len(list(skeleton.flattened_trees())) + group_count = len(list(skeleton.flattened_groups())) + max_node_id = skeleton.get_max_node_id() + + nx_graph = create_dummy_nx_graph() + skeleton.add_nx_graphs( + {"first_group": [nx_graph, nx_graph], "second_group": [nx_graph]} + ) + + # check number of groups, nodes and trees + assert len(list(skeleton.flattened_groups())) == group_count + 2 + assert skeleton.get_total_node_count() == node_count + 6 + assert len(list(skeleton.flattened_trees())) == tree_count + 3 + + # check group names + for group in skeleton.flattened_groups(): + assert group.name in [ + "first_group", + "second_group", + "Example Group", + "Nested Group", + ] + + # check node attributes + max_node_id = skeleton.get_max_node_id() + assert skeleton.get_node_by_id(max_node_id).comment == "node 2 nx" + assert skeleton.get_node_by_id(max_node_id).position == (3, 1, 2) + assert skeleton.get_node_by_id(max_node_id - 1).comment == "node 1 nx" + assert skeleton.get_node_by_id(max_node_id - 1).position == (0, 1, 2) + + # check if edge was added + for edge in skeleton.get_tree_by_id(max_node_id - 2).edges: + assert (edge[0].id, edge[1].id) == (max_node_id - 1, max_node_id) + + +def test_nml_generation(tmp_path: Path) -> None: + OLD_NML_PATH = TESTDATA_DIR / "nmls" / "generate_nml_snapshot.nml" + + tree1 = create_dummy_nx_graph() + tree2 = create_dummy_nx_graph() + tree2.add_node(3, position=(3, 3, 3), comment="node 3 nx") + + tree_dict = {"first_group": [tree1], "second_group": [tree2]} + + # old_nml was generated with the old wknml library as follows: + # params_wknml = {"name": "MyDataset", "scale": (1, 1, 1), "zoomLevel": 0.4} + # old_nml = generate_nml(tree_dict=tree_dict, parameters=params_wknml) + # with open(tmp_path / "annotation_old.nml", "wb") as f: + # write_nml(f, old_nml) + + tree_dict = {"first_group": [tree1], "second_group": [tree2]} + + annotation = wk.Annotation( + name="MyAnnotation", + dataset_name="MyDataset", + voxel_size=(1, 1, 1), + zoom_level=0.4, + ) + + annotation.skeleton.add_nx_graphs(tree_dict) + + annotation.save(tmp_path / "annotation_new.nml") + + old_skeleton = wk.Skeleton.load(OLD_NML_PATH) + new_skeleton = wk.Skeleton.load(tmp_path / "annotation_new.nml") + + for old_group, new_group in zip( + old_skeleton.flattened_groups(), new_skeleton.flattened_groups() + ): + assert old_group.name == new_group.name + for old_child, new_child in zip(old_group.children, new_group.children): + if isinstance(old_child, wk.Tree) and isinstance(new_child, wk.Tree): + for old_node, new_node in zip(old_child.nodes, new_child.nodes): + assert old_node.comment == new_node.comment + assert old_node.position == new_node.position + assert old_node.radius == new_node.radius + for old_edge, new_edge in zip(old_child.edges, new_child.edges): + assert old_edge[0].position == new_edge[0].position + assert old_edge[1].position == new_edge[1].position + + def diff_lines(lines_a: List[str], lines_b: List[str]) -> List[str]: diff = list( difflib.unified_diff( diff --git a/webknossos/webknossos/skeleton/skeleton.py b/webknossos/webknossos/skeleton/skeleton.py index ddde84f78..c4e1c4336 100644 --- a/webknossos/webknossos/skeleton/skeleton.py +++ b/webknossos/webknossos/skeleton/skeleton.py @@ -1,9 +1,10 @@ import itertools from os import PathLike from pathlib import Path -from typing import Iterator, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union import attr +import networkx as nx from ..utils import warn_deprecated from .group import Group @@ -104,6 +105,45 @@ def save(self, out_path: Union[str, PathLike]) -> None: annotation = Annotation(name=out_path.stem, skeleton=self, time=None) annotation.save(out_path) + def add_nx_graphs( + self, tree_dict: Union[List[nx.Graph], Dict[str, List[nx.Graph]]] + ) -> None: + """ + A utility to add nx graphs [NetworkX graph object](https://networkx.org/) to a wk skeleton object. Accepts both a simple list of multiple skeletons/trees or a dictionary grouping skeleton inputs. + + Arguments: + tree_dict (Union[List[nx.Graph], Dict[str, List[nx.Graph]]]): A list of wK tree-like structures as NetworkX graphs or a dictionary of group names and same lists of NetworkX tree objects. + """ + + if not isinstance(tree_dict, dict): + tree_dict = {"main_group": tree_dict} + + for group_name, trees in tree_dict.items(): + group = self.add_group(group_name) + for tree in trees: + tree_name = tree.graph.get("name", f"tree_{len(list(group.trees))}") + wk_tree = group.add_tree(tree_name) + wk_tree.color = tree.graph.get("color", None) + id_node_dict = {} + for id_with_node in tree.nodes(data=True): + old_id, node = id_with_node + node = wk_tree.add_node( + position=node.get("position"), + comment=node.get("comment", None), + radius=node.get("radius", 1.0), + rotation=node.get("rotation", None), + inVp=node.get("inVp", None), + inMag=node.get("inMag", None), + bitDepth=node.get("bitDepth", None), + interpolation=node.get("interpolation", None), + time=node.get("time", None), + is_branchpoint=node.get("is_branchpoint", False), + branchpoint_time=node.get("branchpoint_time", None), + ) + id_node_dict[old_id] = node + for edge in tree.edges(): + wk_tree.add_edge(id_node_dict[edge[0]], id_node_dict[edge[1]]) + @staticmethod def from_path(file_path: Union[PathLike, str]) -> "Skeleton": """Deprecated. Use Skeleton.load instead."""