diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 88a50ee..38c15bb 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -35,7 +35,7 @@ jobs: - name: Run mypy run: mypy binarytree - name: Run pytest - run: py.test --cov=./ --cov-report=xml + run: py.test --cov=binarytree --cov-report=xml - name: Run Sphinx doctest run: python -m sphinx -b doctest docs docs/_build - name: Run Sphinx HTML diff --git a/.gitignore b/.gitignore index f957230..1079769 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,6 @@ ENV/ # setuptools-scm binarytree/version.py + +# Jupyter Notebook +*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b0e4c1..73805d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: rev: v0.790 hooks: - id: mypy - args: [ binarytree ] + files: ^binarytree/ - repo: https://gitlab.com/pycqa/flake8 rev: 3.8.4 hooks: diff --git a/binarytree/__init__.py b/binarytree/__init__.py index 50b93b5..0695713 100644 --- a/binarytree/__init__.py +++ b/binarytree/__init__.py @@ -2,13 +2,15 @@ import heapq import random +from collections import deque from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from subprocess import SubprocessError +from typing import Any, Deque, Dict, Iterator, List, Optional, Tuple, Union +from graphviz import Digraph, ExecutableNotFound, nohtml from pkg_resources import get_distribution from binarytree.exceptions import ( - GraphvizImportError, NodeIndexError, NodeModifyError, NodeNotFoundError, @@ -18,21 +20,31 @@ TreeHeightError, ) -try: - from graphviz import Digraph, nohtml - - GRAPHVIZ_INSTALLED = True -except ImportError: - GRAPHVIZ_INSTALLED = False - Digraph = Any - from binarytree.layout import generate_svg - __version__ = get_distribution("binarytree").version -LEFT_FIELD = "left" -RIGHT_FIELD = "right" -VAL_FIELD = "val" -VALUE_FIELD = "value" +_ATTR_LEFT = "left" +_ATTR_RIGHT = "right" +_ATTR_VAL = "val" +_ATTR_VALUE = "value" +_SVG_XML_TEMPLATE = """ + + + +{body} + + +""" NodeValue = Union[float, int] @@ -82,15 +94,6 @@ def __init__( self.left = left self.right = right - if not isinstance(value, (float, int)): - raise NodeValueError("node value must be a float or int") - - if left is not None and not isinstance(left, Node): - raise NodeTypeError("left child must be a Node instance") - - if right is not None and not isinstance(right, Node): - raise NodeTypeError("right child must be a Node instance") - def __repr__(self) -> str: """Return the string representation of the current node. @@ -180,20 +183,23 @@ def __setattr__(self, attr: str, obj: Any) -> None: ... NodeValueError: node value must be a float or int """ - if attr == LEFT_FIELD: + if attr == _ATTR_LEFT: if obj is not None and not isinstance(obj, Node): raise NodeTypeError("left child must be a Node instance") - elif attr == RIGHT_FIELD: + + elif attr == _ATTR_RIGHT: if obj is not None and not isinstance(obj, Node): raise NodeTypeError("right child must be a Node instance") - elif attr == VALUE_FIELD: + + elif attr == _ATTR_VALUE: if not isinstance(obj, (float, int)): raise NodeValueError("node value must be a float or int") - object.__setattr__(self, VAL_FIELD, obj) - elif attr == VAL_FIELD: + object.__setattr__(self, _ATTR_VAL, obj) + + elif attr == _ATTR_VAL: if not isinstance(obj, (float, int)): raise NodeValueError("node value must be a float or int") - object.__setattr__(self, VALUE_FIELD, obj) + object.__setattr__(self, _ATTR_VALUE, obj) object.__setattr__(self, attr, obj) @@ -229,17 +235,20 @@ def __iter__(self) -> Iterator["Node"]: >>> list(root) [Node(1), Node(2), Node(3), Node(4), Node(5)] """ - current_level = [self] + current_nodes = [self] - while len(current_level) > 0: - next_level = [] - for node in current_level: + while len(current_nodes) > 0: + next_nodes = [] + + for node in current_nodes: yield node + if node.left is not None: - next_level.append(node.left) + next_nodes.append(node.left) if node.right is not None: - next_level.append(node.right) - current_level = next_level + next_nodes.append(node.right) + + current_nodes = next_nodes def __len__(self) -> int: """Return the total number of nodes in the binary tree. @@ -302,15 +311,15 @@ def __getitem__(self, index: int) -> "Node": if not isinstance(index, int) or index < 0: raise NodeIndexError("node index must be a non-negative int") - current_level: List[Optional[Node]] = [self] + current_nodes: List[Optional[Node]] = [self] current_index = 0 has_more_nodes = True while has_more_nodes: has_more_nodes = False - next_level: List[Optional[Node]] = [] + next_nodes: List[Optional[Node]] = [] - for node in current_level: + for node in current_nodes: if current_index == index: if node is None: break @@ -319,15 +328,15 @@ def __getitem__(self, index: int) -> "Node": current_index += 1 if node is None: - next_level.append(None) - next_level.append(None) + next_nodes.append(None) + next_nodes.append(None) continue - next_level.append(node.left) - next_level.append(node.right) + next_nodes.append(node.left) + next_nodes.append(node.right) if node.left is not None or node.right is not None: has_more_nodes = True - current_level = next_level + current_nodes = next_nodes raise NodeNotFoundError("node missing at index {}".format(index)) @@ -403,7 +412,7 @@ def __setitem__(self, index: int, node: "Node") -> None: "parent node missing at index {}".format(parent_index) ) - setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node) + setattr(parent, _ATTR_LEFT if index % 2 else _ATTR_RIGHT, node) def __delitem__(self, index: int) -> None: """Remove the node (or subtree) at the given level-order_ index. @@ -461,29 +470,102 @@ def __delitem__(self, index: int) -> None: except NodeNotFoundError: raise NodeNotFoundError("no node to delete at index {}".format(index)) - child_attr = LEFT_FIELD if index % 2 == 1 else RIGHT_FIELD + child_attr = _ATTR_LEFT if index % 2 == 1 else _ATTR_RIGHT if getattr(parent, child_attr) is None: raise NodeNotFoundError("no node to delete at index {}".format(index)) setattr(parent, child_attr, None) - def _repr_svg_(self) -> str: + def _repr_svg_(self) -> str: # pragma: no cover """Display the binary tree using Graphviz (used for `Jupyter notebooks`_). .. _Jupyter notebooks: https://jupyter.org """ - if GRAPHVIZ_INSTALLED: + try: # noinspection PyProtectedMember return str(self.graphviz()._repr_svg_()) - else: - return generate_svg(self.values) # pragma: no cover - def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: + except (SubprocessError, ExecutableNotFound, FileNotFoundError): + return self.svg() + + def svg(self, node_radius: int = 16) -> str: + """Generate SVG XML. + + :param node_radius: Node radius in pixels (default: 16). + :type node_radius: int + :return: Raw SVG XML. + :rtype: str + """ + tree_height = self.height + scale = node_radius * 3 + xml: Deque[str] = deque() + + def scale_x(x: int, y: int) -> float: + diff = tree_height - y + x = 2 ** (diff + 1) * x + 2 ** diff - 1 + return 1 + node_radius + scale * x / 2 + + def scale_y(y: int) -> float: + return scale * (1 + y) + + def add_edge(parent_x: int, parent_y: int, node_x: int, node_y: int) -> None: + xml.appendleft( + ''.format( + x1=scale_x(parent_x, parent_y), + y1=scale_y(parent_y), + x2=scale_x(node_x, node_y), + y2=scale_y(node_y), + ) + ) + + def add_node(node_x: int, node_y: int, node_value: NodeValue) -> None: + x, y = scale_x(node_x, node_y), scale_y(node_y) + xml.append(f'') + xml.append(f'{node_value}') + + current_nodes = [self.left, self.right] + has_more_nodes = True + y = 1 + + add_node(0, 0, self.value) + + while has_more_nodes: + + has_more_nodes = False + next_nodes: List[Optional[Node]] = [] + + for x, node in enumerate(current_nodes): + if node is None: + next_nodes.append(None) + next_nodes.append(None) + else: + if node.left is not None or node.right is not None: + has_more_nodes = True + + add_edge(x // 2, y - 1, x, y) + add_node(x, y, node.value) + + next_nodes.append(node.left) + next_nodes.append(node.right) + + current_nodes = next_nodes + y += 1 + + return _SVG_XML_TEMPLATE.format( + width=scale * (2 ** tree_height), + height=scale * (2 + tree_height), + body="\n".join(xml), + ) + + def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: # pragma: no cover """Return a graphviz.Digraph_ object representing the binary tree. + This method's positional and keyword arguments are passed directly into the the Digraph's **__init__** method. + :return: graphviz.Digraph_ object representing the binary tree. :raise binarytree.exceptions.GraphvizImportError: If graphviz is not installed + .. code-block:: python >>> from binarytree import tree >>> @@ -492,12 +574,9 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: >>> graph = t.graphviz() # Generate a graphviz object >>> graph.body # Get the DOT body >>> graph.render() # Render the graph + .. _graphviz.Digraph: https://graphviz.readthedocs.io/en/stable/api.html#digraph """ - if not GRAPHVIZ_INSTALLED: - raise GraphvizImportError( - "Can't use graphviz method if graphviz module is not installed" - ) if "node_attr" not in kwargs: kwargs["node_attr"] = { "shape": "record", @@ -507,13 +586,18 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: "fontcolor": "black", } digraph = Digraph(*args, **kwargs) + for node in self: node_id = str(id(node)) + digraph.node(node_id, nohtml(f"| {node.value}|")) + if node.left is not None: digraph.edge(f"{node_id}:l", f"{id(node.left)}:v") + if node.right is not None: digraph.edge(f"{node_id}:r", f"{id(node.right)}:v") + return digraph def pprint(self, index: bool = False, delimiter: str = "-") -> None: @@ -590,46 +674,47 @@ def validate(self) -> None: NodeReferenceError: cyclic node reference at index 0 """ has_more_nodes = True - visited = set() - to_visit: List[Optional[Node]] = [self] - index = 0 + nodes_seen = set() + current_nodes: List[Optional[Node]] = [self] + node_index = 0 # level-order index while has_more_nodes: + has_more_nodes = False - next_level: List[Optional[Node]] = [] + next_nodes: List[Optional[Node]] = [] - for node in to_visit: + for node in current_nodes: if node is None: - next_level.append(None) - next_level.append(None) + next_nodes.append(None) + next_nodes.append(None) else: - if node in visited: + if node in nodes_seen: raise NodeReferenceError( f"cyclic reference at Node({node.val}) " - + f"(level-order index {index})" + + f"(level-order index {node_index})" ) if not isinstance(node, Node): raise NodeTypeError( - "invalid node instance at index {}".format(index) + "invalid node instance at index {}".format(node_index) ) if not isinstance(node.val, (float, int)): raise NodeValueError( - "invalid node value at index {}".format(index) + "invalid node value at index {}".format(node_index) ) - if not isinstance(node.value, (float, int)): + if not isinstance(node.value, (float, int)): # pragma: no cover raise NodeValueError( - "invalid node value at index {}".format(index) + "invalid node value at index {}".format(node_index) ) if node.left is not None or node.right is not None: has_more_nodes = True - visited.add(node) - next_level.append(node.left) - next_level.append(node.right) + nodes_seen.add(node) + next_nodes.append(node.left) + next_nodes.append(node.right) - index += 1 + node_index += 1 - to_visit = next_level + current_nodes = next_nodes @property def values(self) -> List[Optional[NodeValue]]: @@ -660,35 +745,34 @@ def values(self) -> List[Optional[NodeValue]]: >>> root.values [1, 2, 3, None, 4] """ - current_level: List[Optional[Node]] = [self] + current_nodes: List[Optional[Node]] = [self] has_more_nodes = True - values: List[Optional[NodeValue]] = [] + node_values: List[Optional[NodeValue]] = [] while has_more_nodes: has_more_nodes = False - next_level: List[Optional[Node]] = [] + next_nodes: List[Optional[Node]] = [] - for node in current_level: + for node in current_nodes: if node is None: - values.append(None) - next_level.append(None) - next_level.append(None) - continue - - if node.left is not None or node.right is not None: - has_more_nodes = True + node_values.append(None) + next_nodes.append(None) + next_nodes.append(None) + else: + if node.left is not None or node.right is not None: + has_more_nodes = True - values.append(node.val) - next_level.append(node.left) - next_level.append(node.right) + node_values.append(node.val) + next_nodes.append(node.left) + next_nodes.append(node.right) - current_level = next_level + current_nodes = next_nodes # Get rid of trailing None values - while values and values[-1] is None: - values.pop() + while node_values and node_values[-1] is None: + node_values.pop() - return values + return node_values @property def leaves(self) -> List["Node"]: @@ -721,20 +805,20 @@ def leaves(self) -> List["Node"]: >>> root.leaves [Node(3), Node(4)] """ - current_level = [self] + current_nodes = [self] leaves = [] - while len(current_level) > 0: - next_level = [] - for node in current_level: + while len(current_nodes) > 0: + next_nodes = [] + for node in current_nodes: if node.left is None and node.right is None: leaves.append(node) continue if node.left is not None: - next_level.append(node.left) + next_nodes.append(node.left) if node.right is not None: - next_level.append(node.right) - current_level = next_level + next_nodes.append(node.right) + current_nodes = next_nodes return leaves @property @@ -767,18 +851,21 @@ def levels(self) -> List[List["Node"]]: >>> root.levels [[Node(1)], [Node(2), Node(3)], [Node(4)]] """ - current_level = [self] + current_nodes = [self] levels = [] - while len(current_level) > 0: - next_level = [] - for node in current_level: + while len(current_nodes) > 0: + next_nodes = [] + + for node in current_nodes: if node.left is not None: - next_level.append(node.left) + next_nodes.append(node.left) if node.right is not None: - next_level.append(node.right) - levels.append(current_level) - current_level = next_level + next_nodes.append(node.right) + + levels.append(current_nodes) + current_nodes = next_nodes + return levels @property @@ -1489,18 +1576,18 @@ def levelorder(self) -> List["Node"]: >>> root.levelorder [Node(1), Node(2), Node(3), Node(4), Node(5)] """ - current_level = [self] + current_nodes = [self] result = [] - while len(current_level) > 0: - next_level = [] - for node in current_level: + while len(current_nodes) > 0: + next_nodes = [] + for node in current_nodes: result.append(node) if node.left is not None: - next_level.append(node.left) + next_nodes.append(node.left) if node.right is not None: - next_level.append(node.right) - current_level = next_level + next_nodes.append(node.right) + current_nodes = next_nodes return result @@ -1645,7 +1732,10 @@ def _generate_random_node_values(height: int) -> List[int]: def _build_tree_string( - root: Optional[Node], curr_index: int, index: bool = False, delimiter: str = "-" + root: Optional[Node], + curr_index: int, + include_index: bool = False, + delimiter: str = "-", ) -> Tuple[List[str], int, int, int]: """Recursively walk down the binary tree and build a pretty-print string. @@ -1660,9 +1750,9 @@ def _build_tree_string( :type root: binarytree.Node | None :param curr_index: Level-order_ index of the current node (root node is 0). :type curr_index: int - :param index: If set to True, include the level-order_ node indexes using + :param include_index: If set to True, include the level-order_ node indexes using the following format: ``{index}{delimiter}{value}`` (default: False). - :type index: bool + :type include_index: bool :param delimiter: Delimiter character between the node index and the node value (default: '-'). :type delimiter: @@ -1679,7 +1769,7 @@ def _build_tree_string( line1 = [] line2 = [] - if index: + if include_index: node_repr = "{}{}{}".format(curr_index, delimiter, root.val) else: node_repr = str(root.val) @@ -1688,10 +1778,10 @@ def _build_tree_string( # Get the left and right sub-boxes, their widths, and root repr positions l_box, l_box_width, l_root_start, l_root_end = _build_tree_string( - root.left, 2 * curr_index + 1, index, delimiter + root.left, 2 * curr_index + 1, include_index, delimiter ) r_box, r_box_width, r_root_start, r_root_end = _build_tree_string( - root.right, 2 * curr_index + 2, index, delimiter + root.right, 2 * curr_index + 2, include_index, delimiter ) # Draw the branch connecting the current root node to the left sub-box @@ -1752,14 +1842,14 @@ def _get_tree_properties(root: Node) -> NodeProperties: max_leaf_depth = -1 is_strict = True is_complete = True - current_level = [root] + current_nodes = [root] non_full_node_seen = False - while len(current_level) > 0: + while len(current_nodes) > 0: max_leaf_depth += 1 - next_level = [] + next_nodes = [] - for node in current_level: + for node in current_nodes: size += 1 val = node.val min_node_value = min(val, min_node_value) @@ -1776,7 +1866,7 @@ def _get_tree_properties(root: Node) -> NodeProperties: is_descending = False elif node.left.val < val: is_ascending = False - next_level.append(node.left) + next_nodes.append(node.left) is_complete = not non_full_node_seen else: non_full_node_seen = True @@ -1786,7 +1876,7 @@ def _get_tree_properties(root: Node) -> NodeProperties: is_descending = False elif node.right.val < val: is_ascending = False - next_level.append(node.right) + next_nodes.append(node.right) is_complete = not non_full_node_seen else: non_full_node_seen = True @@ -1794,7 +1884,7 @@ def _get_tree_properties(root: Node) -> NodeProperties: # If we see a node with only one child, it is not strict is_strict &= (node.left is None) == (node.right is None) - current_level = next_level + current_nodes = next_nodes return NodeProperties( height=max_leaf_depth, @@ -1918,7 +2008,7 @@ def build(values: List[int]) -> Optional[Node]: raise NodeNotFoundError( "parent node missing at index {}".format(parent_index) ) - setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node) + setattr(parent, _ATTR_LEFT if index % 2 else _ATTR_RIGHT, node) return nodes[0] if nodes else None @@ -1982,7 +2072,7 @@ def tree(height: int = 3, is_perfect: bool = False) -> Optional[Node]: inserted = False while depth < height and not inserted: - attr = random.choice((LEFT_FIELD, RIGHT_FIELD)) + attr = random.choice((_ATTR_LEFT, _ATTR_RIGHT)) if getattr(node, attr) is None: setattr(node, attr, Node(value)) inserted = True @@ -2048,7 +2138,7 @@ def bst(height: int = 3, is_perfect: bool = False) -> Optional[Node]: inserted = False while depth < height and not inserted: - attr = LEFT_FIELD if node.val > value else RIGHT_FIELD + attr = _ATTR_LEFT if node.val > value else _ATTR_RIGHT if getattr(node, attr) is None: setattr(node, attr, Node(value)) inserted = True diff --git a/binarytree/exceptions.py b/binarytree/exceptions.py index e5b22ca..0676e24 100644 --- a/binarytree/exceptions.py +++ b/binarytree/exceptions.py @@ -28,7 +28,3 @@ class NodeValueError(BinaryTreeError): class TreeHeightError(BinaryTreeError): """Tree height was invalid.""" - - -class GraphvizImportError(BinaryTreeError): - """graphviz module is not installed""" diff --git a/binarytree/layout.py b/binarytree/layout.py deleted file mode 100644 index 7d200d4..0000000 --- a/binarytree/layout.py +++ /dev/null @@ -1,118 +0,0 @@ -""" Module containing layout related algorithms.""" -from typing import List, Tuple, Union - - -def _get_coords( - values: List[Union[float, int, None]] -) -> Tuple[ - List[Tuple[int, int, Union[float, int, None]]], List[Tuple[int, int, int, int]] -]: - """Generate the coordinates used for rendering the nodes and edges. - - node and edges are stored as tuples in the form node: (x, y, label) and - edge: (x1, y1, x2, y2) - - Each coordinate is relative y is the depth, x is the position of the node - on a level from left to right 0 to 2**depth -1 - - :param values: Values of the binary tree. - :type values: list of ints - :return: nodes and edges list - :rtype: two lists of tuples - - """ - x = 0 - y = 0 - nodes = [] - edges = [] - - # root node - nodes.append((x, y, values[0])) - # append other nodes and their edges - y += 1 - for value in values[1:]: - if value is not None: - nodes.append((x, y, value)) - edges.append((x // 2, y - 1, x, y)) - x += 1 - # check if level is full - if x == 2 ** y: - x = 0 - y += 1 - return nodes, edges - - -def generate_svg(values: List[Union[float, int, None]]) -> str: - """Generate a svg image from a binary tree - - A simple layout is used based on a perfect tree of same height in which all - leaves would be regularly spaced. - - :param values: Values of the binary tree. - :type values: list of ints - :return: the svg image of the tree. - :rtype: str - """ - node_size = 16.0 - stroke_width = 1.5 - gutter = 0.5 - x_scale = (2 + gutter) * node_size - y_scale = 3.0 * node_size - - # retrieve relative coordinates - nodes, edges = _get_coords(values) - y_min = min([n[1] for n in nodes]) - y_max = max([n[1] for n in nodes]) - - # generate the svg string - svg = f""" - - - - """ - # scales - - def scalex(x: int, y: int) -> float: - depth = y_max - y - # offset - x = 2 ** (depth + 1) * x + 2 ** depth - 1 - return 1 + node_size + x_scale * x / 2 - - def scaley(y: int) -> float: - return float(y_scale * (1 + y - y_min)) - - # edges - def svg_edge(x1: float, y1: float, x2: float, y2: float) -> str: - """Generate svg code for an edge""" - return f"""""" - - for a in edges: - x1, y1, x2, y2 = a - svg += svg_edge(scalex(x1, y1), scaley(y1), scalex(x2, y2), scaley(y2)) - - # nodes - def svg_node(x: float, y: float, label: str = "") -> str: - """Generate svg code for a node and his label""" - return f""" - - {label}""" - - for n in nodes: - x, y, label = n - svg += svg_node(scalex(x, y), scaley(y), str(label)) - - svg += "" - return svg diff --git a/pyproject.toml b/pyproject.toml index 3b3a65f..04adc11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,7 @@ requires = [ build-backend = "setuptools.build_meta" [tool.coverage.run] -omit = [ - "binarytree/version.py", - "setup.py" -] +omit = ["binarytree/version.py"] [tool.isort] profile = "black" diff --git a/tests/test_layout.py b/tests/test_layout.py deleted file mode 100644 index 0a92258..0000000 --- a/tests/test_layout.py +++ /dev/null @@ -1,17 +0,0 @@ -import xml.etree.ElementTree as ET - -from binarytree.layout import _get_coords, generate_svg - - -def test_get_coords(): - values = [0, 6, 5, None, 1, 4, 2] - assert _get_coords(values) == ( - [(0, 0, 0), (0, 1, 6), (1, 1, 5), (1, 2, 1), (2, 2, 4), (3, 2, 2)], - [(0, 0, 0, 1), (0, 0, 1, 1), (0, 1, 1, 2), (1, 1, 2, 2), (1, 1, 3, 2)], - ) - - -def test_svg(): - svg = generate_svg([0, 1, 2]) - svg_tree = ET.fromstring(svg) - assert svg_tree.tag == "{http://www.w3.org/2000/svg}svg" diff --git a/tests/test_tree.py b/tests/test_tree.py index 765169b..119bda6 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -19,7 +19,62 @@ REPETITIONS = 20 - +EXPECTED_SVG_XML_SINGLE_NODE = """ + + + + +0 + + +""" + +EXPECTED_SVG_XML_MULTIPLE_NODES = """ + + + + + + + + +0 + +1 + +2 + +3 + +4 + + +""" + + +# noinspection PyTypeChecker def test_node_set_attributes(): root = Node(1) assert root.left is None @@ -869,6 +924,7 @@ def test_heap_float_values(): assert root.size == root_copy.size +# noinspection PyTypeChecker def test_get_parent(): root = Node(0) root.left = Node(1) @@ -884,3 +940,14 @@ def test_get_parent(): assert get_parent(root, Node(5)) is None assert get_parent(None, root.left) is None assert get_parent(root, None) is None + + +def test_svg_generation(): + root = Node(0) + assert root.svg() == EXPECTED_SVG_XML_SINGLE_NODE + + root.left = Node(1) + root.right = Node(2) + root.left.left = Node(3) + root.right.right = Node(4) + assert root.svg() == EXPECTED_SVG_XML_MULTIPLE_NODES