diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 88a50ee..38c15bb 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -35,7 +35,7 @@ jobs:
- name: Run mypy
run: mypy binarytree
- name: Run pytest
- run: py.test --cov=./ --cov-report=xml
+ run: py.test --cov=binarytree --cov-report=xml
- name: Run Sphinx doctest
run: python -m sphinx -b doctest docs docs/_build
- name: Run Sphinx HTML
diff --git a/.gitignore b/.gitignore
index f957230..1079769 100644
--- a/.gitignore
+++ b/.gitignore
@@ -93,3 +93,6 @@ ENV/
# setuptools-scm
binarytree/version.py
+
+# Jupyter Notebook
+*.ipynb
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2b0e4c1..73805d8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,7 @@ repos:
rev: v0.790
hooks:
- id: mypy
- args: [ binarytree ]
+ files: ^binarytree/
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
hooks:
diff --git a/binarytree/__init__.py b/binarytree/__init__.py
index 50b93b5..0695713 100644
--- a/binarytree/__init__.py
+++ b/binarytree/__init__.py
@@ -2,13 +2,15 @@
import heapq
import random
+from collections import deque
from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from subprocess import SubprocessError
+from typing import Any, Deque, Dict, Iterator, List, Optional, Tuple, Union
+from graphviz import Digraph, ExecutableNotFound, nohtml
from pkg_resources import get_distribution
from binarytree.exceptions import (
- GraphvizImportError,
NodeIndexError,
NodeModifyError,
NodeNotFoundError,
@@ -18,21 +20,31 @@
TreeHeightError,
)
-try:
- from graphviz import Digraph, nohtml
-
- GRAPHVIZ_INSTALLED = True
-except ImportError:
- GRAPHVIZ_INSTALLED = False
- Digraph = Any
- from binarytree.layout import generate_svg
-
__version__ = get_distribution("binarytree").version
-LEFT_FIELD = "left"
-RIGHT_FIELD = "right"
-VAL_FIELD = "val"
-VALUE_FIELD = "value"
+_ATTR_LEFT = "left"
+_ATTR_RIGHT = "right"
+_ATTR_VAL = "val"
+_ATTR_VALUE = "value"
+_SVG_XML_TEMPLATE = """
+
+"""
NodeValue = Union[float, int]
@@ -82,15 +94,6 @@ def __init__(
self.left = left
self.right = right
- if not isinstance(value, (float, int)):
- raise NodeValueError("node value must be a float or int")
-
- if left is not None and not isinstance(left, Node):
- raise NodeTypeError("left child must be a Node instance")
-
- if right is not None and not isinstance(right, Node):
- raise NodeTypeError("right child must be a Node instance")
-
def __repr__(self) -> str:
"""Return the string representation of the current node.
@@ -180,20 +183,23 @@ def __setattr__(self, attr: str, obj: Any) -> None:
...
NodeValueError: node value must be a float or int
"""
- if attr == LEFT_FIELD:
+ if attr == _ATTR_LEFT:
if obj is not None and not isinstance(obj, Node):
raise NodeTypeError("left child must be a Node instance")
- elif attr == RIGHT_FIELD:
+
+ elif attr == _ATTR_RIGHT:
if obj is not None and not isinstance(obj, Node):
raise NodeTypeError("right child must be a Node instance")
- elif attr == VALUE_FIELD:
+
+ elif attr == _ATTR_VALUE:
if not isinstance(obj, (float, int)):
raise NodeValueError("node value must be a float or int")
- object.__setattr__(self, VAL_FIELD, obj)
- elif attr == VAL_FIELD:
+ object.__setattr__(self, _ATTR_VAL, obj)
+
+ elif attr == _ATTR_VAL:
if not isinstance(obj, (float, int)):
raise NodeValueError("node value must be a float or int")
- object.__setattr__(self, VALUE_FIELD, obj)
+ object.__setattr__(self, _ATTR_VALUE, obj)
object.__setattr__(self, attr, obj)
@@ -229,17 +235,20 @@ def __iter__(self) -> Iterator["Node"]:
>>> list(root)
[Node(1), Node(2), Node(3), Node(4), Node(5)]
"""
- current_level = [self]
+ current_nodes = [self]
- while len(current_level) > 0:
- next_level = []
- for node in current_level:
+ while len(current_nodes) > 0:
+ next_nodes = []
+
+ for node in current_nodes:
yield node
+
if node.left is not None:
- next_level.append(node.left)
+ next_nodes.append(node.left)
if node.right is not None:
- next_level.append(node.right)
- current_level = next_level
+ next_nodes.append(node.right)
+
+ current_nodes = next_nodes
def __len__(self) -> int:
"""Return the total number of nodes in the binary tree.
@@ -302,15 +311,15 @@ def __getitem__(self, index: int) -> "Node":
if not isinstance(index, int) or index < 0:
raise NodeIndexError("node index must be a non-negative int")
- current_level: List[Optional[Node]] = [self]
+ current_nodes: List[Optional[Node]] = [self]
current_index = 0
has_more_nodes = True
while has_more_nodes:
has_more_nodes = False
- next_level: List[Optional[Node]] = []
+ next_nodes: List[Optional[Node]] = []
- for node in current_level:
+ for node in current_nodes:
if current_index == index:
if node is None:
break
@@ -319,15 +328,15 @@ def __getitem__(self, index: int) -> "Node":
current_index += 1
if node is None:
- next_level.append(None)
- next_level.append(None)
+ next_nodes.append(None)
+ next_nodes.append(None)
continue
- next_level.append(node.left)
- next_level.append(node.right)
+ next_nodes.append(node.left)
+ next_nodes.append(node.right)
if node.left is not None or node.right is not None:
has_more_nodes = True
- current_level = next_level
+ current_nodes = next_nodes
raise NodeNotFoundError("node missing at index {}".format(index))
@@ -403,7 +412,7 @@ def __setitem__(self, index: int, node: "Node") -> None:
"parent node missing at index {}".format(parent_index)
)
- setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node)
+ setattr(parent, _ATTR_LEFT if index % 2 else _ATTR_RIGHT, node)
def __delitem__(self, index: int) -> None:
"""Remove the node (or subtree) at the given level-order_ index.
@@ -461,29 +470,102 @@ def __delitem__(self, index: int) -> None:
except NodeNotFoundError:
raise NodeNotFoundError("no node to delete at index {}".format(index))
- child_attr = LEFT_FIELD if index % 2 == 1 else RIGHT_FIELD
+ child_attr = _ATTR_LEFT if index % 2 == 1 else _ATTR_RIGHT
if getattr(parent, child_attr) is None:
raise NodeNotFoundError("no node to delete at index {}".format(index))
setattr(parent, child_attr, None)
- def _repr_svg_(self) -> str:
+ def _repr_svg_(self) -> str: # pragma: no cover
"""Display the binary tree using Graphviz (used for `Jupyter notebooks`_).
.. _Jupyter notebooks: https://jupyter.org
"""
- if GRAPHVIZ_INSTALLED:
+ try:
# noinspection PyProtectedMember
return str(self.graphviz()._repr_svg_())
- else:
- return generate_svg(self.values) # pragma: no cover
- def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
+ except (SubprocessError, ExecutableNotFound, FileNotFoundError):
+ return self.svg()
+
+ def svg(self, node_radius: int = 16) -> str:
+ """Generate SVG XML.
+
+ :param node_radius: Node radius in pixels (default: 16).
+ :type node_radius: int
+ :return: Raw SVG XML.
+ :rtype: str
+ """
+ tree_height = self.height
+ scale = node_radius * 3
+ xml: Deque[str] = deque()
+
+ def scale_x(x: int, y: int) -> float:
+ diff = tree_height - y
+ x = 2 ** (diff + 1) * x + 2 ** diff - 1
+ return 1 + node_radius + scale * x / 2
+
+ def scale_y(y: int) -> float:
+ return scale * (1 + y)
+
+ def add_edge(parent_x: int, parent_y: int, node_x: int, node_y: int) -> None:
+ xml.appendleft(
+ ''.format(
+ x1=scale_x(parent_x, parent_y),
+ y1=scale_y(parent_y),
+ x2=scale_x(node_x, node_y),
+ y2=scale_y(node_y),
+ )
+ )
+
+ def add_node(node_x: int, node_y: int, node_value: NodeValue) -> None:
+ x, y = scale_x(node_x, node_y), scale_y(node_y)
+ xml.append(f'')
+ xml.append(f'{node_value}')
+
+ current_nodes = [self.left, self.right]
+ has_more_nodes = True
+ y = 1
+
+ add_node(0, 0, self.value)
+
+ while has_more_nodes:
+
+ has_more_nodes = False
+ next_nodes: List[Optional[Node]] = []
+
+ for x, node in enumerate(current_nodes):
+ if node is None:
+ next_nodes.append(None)
+ next_nodes.append(None)
+ else:
+ if node.left is not None or node.right is not None:
+ has_more_nodes = True
+
+ add_edge(x // 2, y - 1, x, y)
+ add_node(x, y, node.value)
+
+ next_nodes.append(node.left)
+ next_nodes.append(node.right)
+
+ current_nodes = next_nodes
+ y += 1
+
+ return _SVG_XML_TEMPLATE.format(
+ width=scale * (2 ** tree_height),
+ height=scale * (2 + tree_height),
+ body="\n".join(xml),
+ )
+
+ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: # pragma: no cover
"""Return a graphviz.Digraph_ object representing the binary tree.
+
This method's positional and keyword arguments are passed directly into the
the Digraph's **__init__** method.
+
:return: graphviz.Digraph_ object representing the binary tree.
:raise binarytree.exceptions.GraphvizImportError: If graphviz is not installed
+
.. code-block:: python
>>> from binarytree import tree
>>>
@@ -492,12 +574,9 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
>>> graph = t.graphviz() # Generate a graphviz object
>>> graph.body # Get the DOT body
>>> graph.render() # Render the graph
+
.. _graphviz.Digraph: https://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
- if not GRAPHVIZ_INSTALLED:
- raise GraphvizImportError(
- "Can't use graphviz method if graphviz module is not installed"
- )
if "node_attr" not in kwargs:
kwargs["node_attr"] = {
"shape": "record",
@@ -507,13 +586,18 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
"fontcolor": "black",
}
digraph = Digraph(*args, **kwargs)
+
for node in self:
node_id = str(id(node))
+
digraph.node(node_id, nohtml(f"| {node.value}|"))
+
if node.left is not None:
digraph.edge(f"{node_id}:l", f"{id(node.left)}:v")
+
if node.right is not None:
digraph.edge(f"{node_id}:r", f"{id(node.right)}:v")
+
return digraph
def pprint(self, index: bool = False, delimiter: str = "-") -> None:
@@ -590,46 +674,47 @@ def validate(self) -> None:
NodeReferenceError: cyclic node reference at index 0
"""
has_more_nodes = True
- visited = set()
- to_visit: List[Optional[Node]] = [self]
- index = 0
+ nodes_seen = set()
+ current_nodes: List[Optional[Node]] = [self]
+ node_index = 0 # level-order index
while has_more_nodes:
+
has_more_nodes = False
- next_level: List[Optional[Node]] = []
+ next_nodes: List[Optional[Node]] = []
- for node in to_visit:
+ for node in current_nodes:
if node is None:
- next_level.append(None)
- next_level.append(None)
+ next_nodes.append(None)
+ next_nodes.append(None)
else:
- if node in visited:
+ if node in nodes_seen:
raise NodeReferenceError(
f"cyclic reference at Node({node.val}) "
- + f"(level-order index {index})"
+ + f"(level-order index {node_index})"
)
if not isinstance(node, Node):
raise NodeTypeError(
- "invalid node instance at index {}".format(index)
+ "invalid node instance at index {}".format(node_index)
)
if not isinstance(node.val, (float, int)):
raise NodeValueError(
- "invalid node value at index {}".format(index)
+ "invalid node value at index {}".format(node_index)
)
- if not isinstance(node.value, (float, int)):
+ if not isinstance(node.value, (float, int)): # pragma: no cover
raise NodeValueError(
- "invalid node value at index {}".format(index)
+ "invalid node value at index {}".format(node_index)
)
if node.left is not None or node.right is not None:
has_more_nodes = True
- visited.add(node)
- next_level.append(node.left)
- next_level.append(node.right)
+ nodes_seen.add(node)
+ next_nodes.append(node.left)
+ next_nodes.append(node.right)
- index += 1
+ node_index += 1
- to_visit = next_level
+ current_nodes = next_nodes
@property
def values(self) -> List[Optional[NodeValue]]:
@@ -660,35 +745,34 @@ def values(self) -> List[Optional[NodeValue]]:
>>> root.values
[1, 2, 3, None, 4]
"""
- current_level: List[Optional[Node]] = [self]
+ current_nodes: List[Optional[Node]] = [self]
has_more_nodes = True
- values: List[Optional[NodeValue]] = []
+ node_values: List[Optional[NodeValue]] = []
while has_more_nodes:
has_more_nodes = False
- next_level: List[Optional[Node]] = []
+ next_nodes: List[Optional[Node]] = []
- for node in current_level:
+ for node in current_nodes:
if node is None:
- values.append(None)
- next_level.append(None)
- next_level.append(None)
- continue
-
- if node.left is not None or node.right is not None:
- has_more_nodes = True
+ node_values.append(None)
+ next_nodes.append(None)
+ next_nodes.append(None)
+ else:
+ if node.left is not None or node.right is not None:
+ has_more_nodes = True
- values.append(node.val)
- next_level.append(node.left)
- next_level.append(node.right)
+ node_values.append(node.val)
+ next_nodes.append(node.left)
+ next_nodes.append(node.right)
- current_level = next_level
+ current_nodes = next_nodes
# Get rid of trailing None values
- while values and values[-1] is None:
- values.pop()
+ while node_values and node_values[-1] is None:
+ node_values.pop()
- return values
+ return node_values
@property
def leaves(self) -> List["Node"]:
@@ -721,20 +805,20 @@ def leaves(self) -> List["Node"]:
>>> root.leaves
[Node(3), Node(4)]
"""
- current_level = [self]
+ current_nodes = [self]
leaves = []
- while len(current_level) > 0:
- next_level = []
- for node in current_level:
+ while len(current_nodes) > 0:
+ next_nodes = []
+ for node in current_nodes:
if node.left is None and node.right is None:
leaves.append(node)
continue
if node.left is not None:
- next_level.append(node.left)
+ next_nodes.append(node.left)
if node.right is not None:
- next_level.append(node.right)
- current_level = next_level
+ next_nodes.append(node.right)
+ current_nodes = next_nodes
return leaves
@property
@@ -767,18 +851,21 @@ def levels(self) -> List[List["Node"]]:
>>> root.levels
[[Node(1)], [Node(2), Node(3)], [Node(4)]]
"""
- current_level = [self]
+ current_nodes = [self]
levels = []
- while len(current_level) > 0:
- next_level = []
- for node in current_level:
+ while len(current_nodes) > 0:
+ next_nodes = []
+
+ for node in current_nodes:
if node.left is not None:
- next_level.append(node.left)
+ next_nodes.append(node.left)
if node.right is not None:
- next_level.append(node.right)
- levels.append(current_level)
- current_level = next_level
+ next_nodes.append(node.right)
+
+ levels.append(current_nodes)
+ current_nodes = next_nodes
+
return levels
@property
@@ -1489,18 +1576,18 @@ def levelorder(self) -> List["Node"]:
>>> root.levelorder
[Node(1), Node(2), Node(3), Node(4), Node(5)]
"""
- current_level = [self]
+ current_nodes = [self]
result = []
- while len(current_level) > 0:
- next_level = []
- for node in current_level:
+ while len(current_nodes) > 0:
+ next_nodes = []
+ for node in current_nodes:
result.append(node)
if node.left is not None:
- next_level.append(node.left)
+ next_nodes.append(node.left)
if node.right is not None:
- next_level.append(node.right)
- current_level = next_level
+ next_nodes.append(node.right)
+ current_nodes = next_nodes
return result
@@ -1645,7 +1732,10 @@ def _generate_random_node_values(height: int) -> List[int]:
def _build_tree_string(
- root: Optional[Node], curr_index: int, index: bool = False, delimiter: str = "-"
+ root: Optional[Node],
+ curr_index: int,
+ include_index: bool = False,
+ delimiter: str = "-",
) -> Tuple[List[str], int, int, int]:
"""Recursively walk down the binary tree and build a pretty-print string.
@@ -1660,9 +1750,9 @@ def _build_tree_string(
:type root: binarytree.Node | None
:param curr_index: Level-order_ index of the current node (root node is 0).
:type curr_index: int
- :param index: If set to True, include the level-order_ node indexes using
+ :param include_index: If set to True, include the level-order_ node indexes using
the following format: ``{index}{delimiter}{value}`` (default: False).
- :type index: bool
+ :type include_index: bool
:param delimiter: Delimiter character between the node index and the node
value (default: '-').
:type delimiter:
@@ -1679,7 +1769,7 @@ def _build_tree_string(
line1 = []
line2 = []
- if index:
+ if include_index:
node_repr = "{}{}{}".format(curr_index, delimiter, root.val)
else:
node_repr = str(root.val)
@@ -1688,10 +1778,10 @@ def _build_tree_string(
# Get the left and right sub-boxes, their widths, and root repr positions
l_box, l_box_width, l_root_start, l_root_end = _build_tree_string(
- root.left, 2 * curr_index + 1, index, delimiter
+ root.left, 2 * curr_index + 1, include_index, delimiter
)
r_box, r_box_width, r_root_start, r_root_end = _build_tree_string(
- root.right, 2 * curr_index + 2, index, delimiter
+ root.right, 2 * curr_index + 2, include_index, delimiter
)
# Draw the branch connecting the current root node to the left sub-box
@@ -1752,14 +1842,14 @@ def _get_tree_properties(root: Node) -> NodeProperties:
max_leaf_depth = -1
is_strict = True
is_complete = True
- current_level = [root]
+ current_nodes = [root]
non_full_node_seen = False
- while len(current_level) > 0:
+ while len(current_nodes) > 0:
max_leaf_depth += 1
- next_level = []
+ next_nodes = []
- for node in current_level:
+ for node in current_nodes:
size += 1
val = node.val
min_node_value = min(val, min_node_value)
@@ -1776,7 +1866,7 @@ def _get_tree_properties(root: Node) -> NodeProperties:
is_descending = False
elif node.left.val < val:
is_ascending = False
- next_level.append(node.left)
+ next_nodes.append(node.left)
is_complete = not non_full_node_seen
else:
non_full_node_seen = True
@@ -1786,7 +1876,7 @@ def _get_tree_properties(root: Node) -> NodeProperties:
is_descending = False
elif node.right.val < val:
is_ascending = False
- next_level.append(node.right)
+ next_nodes.append(node.right)
is_complete = not non_full_node_seen
else:
non_full_node_seen = True
@@ -1794,7 +1884,7 @@ def _get_tree_properties(root: Node) -> NodeProperties:
# If we see a node with only one child, it is not strict
is_strict &= (node.left is None) == (node.right is None)
- current_level = next_level
+ current_nodes = next_nodes
return NodeProperties(
height=max_leaf_depth,
@@ -1918,7 +2008,7 @@ def build(values: List[int]) -> Optional[Node]:
raise NodeNotFoundError(
"parent node missing at index {}".format(parent_index)
)
- setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node)
+ setattr(parent, _ATTR_LEFT if index % 2 else _ATTR_RIGHT, node)
return nodes[0] if nodes else None
@@ -1982,7 +2072,7 @@ def tree(height: int = 3, is_perfect: bool = False) -> Optional[Node]:
inserted = False
while depth < height and not inserted:
- attr = random.choice((LEFT_FIELD, RIGHT_FIELD))
+ attr = random.choice((_ATTR_LEFT, _ATTR_RIGHT))
if getattr(node, attr) is None:
setattr(node, attr, Node(value))
inserted = True
@@ -2048,7 +2138,7 @@ def bst(height: int = 3, is_perfect: bool = False) -> Optional[Node]:
inserted = False
while depth < height and not inserted:
- attr = LEFT_FIELD if node.val > value else RIGHT_FIELD
+ attr = _ATTR_LEFT if node.val > value else _ATTR_RIGHT
if getattr(node, attr) is None:
setattr(node, attr, Node(value))
inserted = True
diff --git a/binarytree/exceptions.py b/binarytree/exceptions.py
index e5b22ca..0676e24 100644
--- a/binarytree/exceptions.py
+++ b/binarytree/exceptions.py
@@ -28,7 +28,3 @@ class NodeValueError(BinaryTreeError):
class TreeHeightError(BinaryTreeError):
"""Tree height was invalid."""
-
-
-class GraphvizImportError(BinaryTreeError):
- """graphviz module is not installed"""
diff --git a/binarytree/layout.py b/binarytree/layout.py
deleted file mode 100644
index 7d200d4..0000000
--- a/binarytree/layout.py
+++ /dev/null
@@ -1,118 +0,0 @@
-""" Module containing layout related algorithms."""
-from typing import List, Tuple, Union
-
-
-def _get_coords(
- values: List[Union[float, int, None]]
-) -> Tuple[
- List[Tuple[int, int, Union[float, int, None]]], List[Tuple[int, int, int, int]]
-]:
- """Generate the coordinates used for rendering the nodes and edges.
-
- node and edges are stored as tuples in the form node: (x, y, label) and
- edge: (x1, y1, x2, y2)
-
- Each coordinate is relative y is the depth, x is the position of the node
- on a level from left to right 0 to 2**depth -1
-
- :param values: Values of the binary tree.
- :type values: list of ints
- :return: nodes and edges list
- :rtype: two lists of tuples
-
- """
- x = 0
- y = 0
- nodes = []
- edges = []
-
- # root node
- nodes.append((x, y, values[0]))
- # append other nodes and their edges
- y += 1
- for value in values[1:]:
- if value is not None:
- nodes.append((x, y, value))
- edges.append((x // 2, y - 1, x, y))
- x += 1
- # check if level is full
- if x == 2 ** y:
- x = 0
- y += 1
- return nodes, edges
-
-
-def generate_svg(values: List[Union[float, int, None]]) -> str:
- """Generate a svg image from a binary tree
-
- A simple layout is used based on a perfect tree of same height in which all
- leaves would be regularly spaced.
-
- :param values: Values of the binary tree.
- :type values: list of ints
- :return: the svg image of the tree.
- :rtype: str
- """
- node_size = 16.0
- stroke_width = 1.5
- gutter = 0.5
- x_scale = (2 + gutter) * node_size
- y_scale = 3.0 * node_size
-
- # retrieve relative coordinates
- nodes, edges = _get_coords(values)
- y_min = min([n[1] for n in nodes])
- y_max = max([n[1] for n in nodes])
-
- # generate the svg string
- svg = f"""
- "
- return svg
diff --git a/pyproject.toml b/pyproject.toml
index 3b3a65f..04adc11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,10 +7,7 @@ requires = [
build-backend = "setuptools.build_meta"
[tool.coverage.run]
-omit = [
- "binarytree/version.py",
- "setup.py"
-]
+omit = ["binarytree/version.py"]
[tool.isort]
profile = "black"
diff --git a/tests/test_layout.py b/tests/test_layout.py
deleted file mode 100644
index 0a92258..0000000
--- a/tests/test_layout.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import xml.etree.ElementTree as ET
-
-from binarytree.layout import _get_coords, generate_svg
-
-
-def test_get_coords():
- values = [0, 6, 5, None, 1, 4, 2]
- assert _get_coords(values) == (
- [(0, 0, 0), (0, 1, 6), (1, 1, 5), (1, 2, 1), (2, 2, 4), (3, 2, 2)],
- [(0, 0, 0, 1), (0, 0, 1, 1), (0, 1, 1, 2), (1, 1, 2, 2), (1, 1, 3, 2)],
- )
-
-
-def test_svg():
- svg = generate_svg([0, 1, 2])
- svg_tree = ET.fromstring(svg)
- assert svg_tree.tag == "{http://www.w3.org/2000/svg}svg"
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 765169b..119bda6 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -19,7 +19,62 @@
REPETITIONS = 20
-
+EXPECTED_SVG_XML_SINGLE_NODE = """
+
+"""
+
+EXPECTED_SVG_XML_MULTIPLE_NODES = """
+
+"""
+
+
+# noinspection PyTypeChecker
def test_node_set_attributes():
root = Node(1)
assert root.left is None
@@ -869,6 +924,7 @@ def test_heap_float_values():
assert root.size == root_copy.size
+# noinspection PyTypeChecker
def test_get_parent():
root = Node(0)
root.left = Node(1)
@@ -884,3 +940,14 @@ def test_get_parent():
assert get_parent(root, Node(5)) is None
assert get_parent(None, root.left) is None
assert get_parent(root, None) is None
+
+
+def test_svg_generation():
+ root = Node(0)
+ assert root.svg() == EXPECTED_SVG_XML_SINGLE_NODE
+
+ root.left = Node(1)
+ root.right = Node(2)
+ root.left.left = Node(3)
+ root.right.right = Node(4)
+ assert root.svg() == EXPECTED_SVG_XML_MULTIPLE_NODES