From f74a10e24a22fb000e793f27fc5f7c144e817903 Mon Sep 17 00:00:00 2001 From: Joohwan Oh Date: Tue, 9 Feb 2021 01:01:22 -0800 Subject: [PATCH] Add type annotations --- binarytree/__init__.py | 303 +++++++++++++++++++++-------------------- 1 file changed, 156 insertions(+), 147 deletions(-) diff --git a/binarytree/__init__.py b/binarytree/__init__.py index 43981c3..503b14c 100644 --- a/binarytree/__init__.py +++ b/binarytree/__init__.py @@ -2,7 +2,7 @@ import heapq import random -from typing import List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from graphviz import Digraph, nohtml from pkg_resources import get_distribution @@ -19,61 +19,13 @@ __version__ = get_distribution("binarytree").version -LEFT = "left" -RIGHT = "right" -VAL = "val" -VALUE = "value" +LEFT_FIELD = "left" +RIGHT_FIELD = "right" +VAL_FIELD = "val" +VALUE_FIELD = "value" - -def get_parent(root, child): - """Search the binary tree and return the parent of given child. - - :param root: Root node of the binary tree. - :type: binarytree.Node - :param child: Child node. - :rtype: binarytree.Node - :return: Parent node, or None if missing. - :rtype: binarytree.Node - - **Example**: - - .. doctest:: - - >>> from binarytree import Node, get_parent - >>> - >>> root = Node(1) - >>> root.left = Node(2) - >>> root.right = Node(3) - >>> root.left.right = Node(4) - >>> - >>> print(root) - - __1 - / \\ - 2 3 - \\ - 4 - - >>> print(get_parent(root, root.left.right)) - - 2 - \\ - 4 - - """ - if child is None: - return None - - stack = [root] - while stack: - node = stack.pop() - if node: - if node.left is child or node.right is child: - return node - else: - stack.append(node.left) - stack.append(node.right) - return None +NodeValue = Union[float, int] +NodeProperty = Union[float, int, bool] class Node: @@ -98,10 +50,13 @@ class Node: def __init__( self, - value: Union[float, int], + value: NodeValue, left: Optional["Node"] = None, right: Optional["Node"] = None, - ): + ) -> None: + self.value = self.val = value + self.left = left + self.right = right if not isinstance(value, (float, int)): raise NodeValueError("node value must be a float or int") @@ -112,15 +67,11 @@ def __init__( if right is not None and not isinstance(right, Node): raise NodeTypeError("right child must be a Node instance") - self.value = self.val = value - self.left = left - self.right = right - - def __repr__(self): + def __repr__(self) -> str: """Return the string representation of the current node. :return: String representation. - :rtype: str | unicode + :rtype: str **Example**: @@ -133,11 +84,11 @@ def __repr__(self): """ return "Node({})".format(self.val) - def __str__(self): + def __str__(self) -> str: """Return the pretty-print string for the binary tree. :return: Pretty-print string. - :rtype: str | unicode + :rtype: str **Example**: @@ -169,13 +120,13 @@ def __str__(self): lines = _build_tree_string(self, 0, False, "-")[0] return "\n" + "\n".join((line.rstrip() for line in lines)) - def __setattr__(self, attr, obj): + def __setattr__(self, attr: str, obj: Any) -> None: """Modified version of ``__setattr__`` with extra sanity checking. Class attributes **left**, **right** and **value** are validated. :param attr: Name of the class attribute. - :type attr: str | unicode + :type attr: str :param obj: Object to set. :type obj: object :raise binarytree.exceptions.NodeTypeError: If left or right child is @@ -205,24 +156,24 @@ def __setattr__(self, attr, obj): ... NodeValueError: node value must be a float or int """ - if attr == LEFT: + if attr == LEFT_FIELD: if obj is not None and not isinstance(obj, Node): raise NodeTypeError("left child must be a Node instance") - elif attr == RIGHT: + elif attr == RIGHT_FIELD: if obj is not None and not isinstance(obj, Node): raise NodeTypeError("right child must be a Node instance") - elif attr == VALUE: + elif attr == VALUE_FIELD: if not isinstance(obj, (float, int)): raise NodeValueError("node value must be a float or int") - object.__setattr__(self, VAL, obj) - elif attr == VAL: + object.__setattr__(self, VAL_FIELD, obj) + elif attr == VAL_FIELD: if not isinstance(obj, (float, int)): raise NodeValueError("node value must be a float or int") - object.__setattr__(self, VALUE, obj) + object.__setattr__(self, VALUE_FIELD, obj) object.__setattr__(self, attr, obj) - def __iter__(self): + def __iter__(self) -> Iterator["Node"]: """Iterate through the nodes in the binary tree in level-order_. .. _level-order: @@ -266,7 +217,7 @@ def __iter__(self): next_level.append(node.right) current_level = next_level - def __len__(self): + def __len__(self) -> int: """Return the total number of nodes in the binary tree. :return: Total number of nodes. @@ -290,7 +241,7 @@ def __len__(self): """ return self.properties["size"] - def __getitem__(self, index): + def __getitem__(self, index: int) -> "Node": """Return the node (or subtree) at the given level-order_ index. .. _level-order: @@ -327,13 +278,13 @@ def __getitem__(self, index): if not isinstance(index, int) or index < 0: raise NodeIndexError("node index must be a non-negative int") - current_level = [self] + current_level: List[Optional[Node]] = [self] current_index = 0 has_more_nodes = True while has_more_nodes: has_more_nodes = False - next_level = [] + next_level: List[Optional[Node]] = [] for node in current_level: if current_index == index: @@ -356,7 +307,7 @@ def __getitem__(self, index): raise NodeNotFoundError("node missing at index {}".format(index)) - def __setitem__(self, index, node): + def __setitem__(self, index: int, node: "Node") -> None: """Insert a node (or subtree) at the given level-order_ index. * An exception is raised if the parent node is missing. @@ -428,9 +379,9 @@ def __setitem__(self, index, node): "parent node missing at index {}".format(parent_index) ) - setattr(parent, LEFT if index % 2 else RIGHT, node) + setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node) - def __delitem__(self, index): + def __delitem__(self, index: int) -> None: """Remove the node (or subtree) at the given level-order_ index. * An exception is raised if the target node is missing. @@ -486,13 +437,13 @@ def __delitem__(self, index): except NodeNotFoundError: raise NodeNotFoundError("no node to delete at index {}".format(index)) - child_attr = LEFT if index % 2 == 1 else RIGHT + child_attr = LEFT_FIELD if index % 2 == 1 else RIGHT_FIELD if getattr(parent, child_attr) is None: raise NodeNotFoundError("no node to delete at index {}".format(index)) setattr(parent, child_attr, None) - def _repr_svg_(self): + def _repr_svg_(self) -> str: """Display the binary tree using Graphviz (used for `Jupyter notebooks`_). .. _Jupyter notebooks: https://jupyter.org @@ -500,7 +451,7 @@ def _repr_svg_(self): # noinspection PyProtectedMember return self.graphviz()._repr_svg_() - def graphviz(self, *args, **kwargs) -> Digraph: + def graphviz(self, *args: Any, **kwargs: Any) -> Digraph: """Return a graphviz.Digraph_ object representing the binary tree. This method's positional and keyword arguments are passed directly into the @@ -544,7 +495,7 @@ def graphviz(self, *args, **kwargs) -> Digraph: return digraph - def pprint(self, index=False, delimiter="-"): + def pprint(self, index: bool = False, delimiter: str = "-") -> None: """Pretty-print the binary tree. :param index: If set to True (default: False), display level-order_ @@ -552,7 +503,7 @@ def pprint(self, index=False, delimiter="-"): :type index: bool :param delimiter: Delimiter character between the node index and the node value (default: '-'). - :type delimiter: str | unicode + :type delimiter: str **Example**: @@ -592,7 +543,7 @@ def pprint(self, index=False, delimiter="-"): lines = _build_tree_string(self, 0, index, delimiter)[0] print("\n" + "\n".join((line.rstrip() for line in lines))) - def validate(self): + def validate(self) -> None: """Check if the binary tree is malformed. :raise binarytree.exceptions.NodeReferenceError: If there is a @@ -619,12 +570,12 @@ def validate(self): """ has_more_nodes = True visited = set() - to_visit = [self] + to_visit: List[Optional[Node]] = [self] index = 0 while has_more_nodes: has_more_nodes = False - next_level = [] + next_level: List[Optional[Node]] = [] for node in to_visit: if node is None: @@ -660,7 +611,7 @@ def validate(self): to_visit = next_level @property - def values(self): + def values(self) -> List[Optional[NodeValue]]: """Return the `list representation`_ of the binary tree. .. _list representation: @@ -672,7 +623,7 @@ def values(self): right child at 2i + 2, and parent at index floor((i - 1) / 2). None indicates absence of a node at that index. See example below for an illustration. - :rtype: [int | float | None] + :rtype: [float | int | None] **Example**: @@ -688,13 +639,14 @@ def values(self): >>> root.values [1, 2, 3, None, 4] """ - current_level = [self] + current_level: List[Optional[Node]] = [self] has_more_nodes = True - values = [] + values: List[Optional[NodeValue]] = [] while has_more_nodes: has_more_nodes = False - next_level = [] + next_level: List[Optional[Node]] = [] + for node in current_level: if node is None: values.append(None) @@ -711,14 +663,14 @@ def values(self): current_level = next_level - # Get rid of trailing None's + # Get rid of trailing None values while values and values[-1] is None: values.pop() return values @property - def leaves(self): + def leaves(self) -> List["Node"]: """Return the leaf nodes of the binary tree. A leaf node is any node that does not have child nodes. @@ -765,7 +717,7 @@ def leaves(self): return leaves @property - def levels(self): + def levels(self) -> List[List["Node"]]: """Return the nodes in the binary tree level by level. :return: Lists of nodes level by level. @@ -809,7 +761,7 @@ def levels(self): return levels @property - def height(self): + def height(self) -> int: """Return the height of the binary tree. Height of a binary tree is the number of edges on the longest path @@ -846,7 +798,7 @@ def height(self): return _get_tree_properties(self)["height"] @property - def size(self): + def size(self) -> int: """Return the total number of nodes in the binary tree. :return: Total number of nodes. @@ -872,7 +824,7 @@ def size(self): return _get_tree_properties(self)["size"] @property - def leaf_count(self): + def leaf_count(self) -> int: """Return the total number of leaf nodes in the binary tree. A leaf node is a node with no child nodes. @@ -897,7 +849,7 @@ def leaf_count(self): return _get_tree_properties(self)["leaf_count"] @property - def is_balanced(self): + def is_balanced(self) -> bool: """Check if the binary tree is height-balanced. A binary tree is height-balanced if it meets the following criteria: @@ -935,7 +887,7 @@ def is_balanced(self): return _is_balanced(self) >= 0 @property - def is_bst(self): + def is_bst(self) -> bool: """Check if the binary tree is a BST_ (binary search tree). :return: True if the binary tree is a BST_, False otherwise. @@ -965,7 +917,7 @@ def is_bst(self): return _is_bst(self) @property - def is_symmetric(self): + def is_symmetric(self) -> bool: """Check if the binary tree is symmetric. A binary tree is symmetric if it meets the following criteria: @@ -1003,7 +955,7 @@ def is_symmetric(self): return _is_symmetric(self) @property - def is_max_heap(self): + def is_max_heap(self) -> bool: """Check if the binary tree is a `max heap`_. :return: True if the binary tree is a `max heap`_, False otherwise. @@ -1033,7 +985,7 @@ def is_max_heap(self): return _get_tree_properties(self)["is_max_heap"] @property - def is_min_heap(self): + def is_min_heap(self) -> bool: """Check if the binary tree is a `min heap`_. :return: True if the binary tree is a `min heap`_, False otherwise. @@ -1063,7 +1015,7 @@ def is_min_heap(self): return _get_tree_properties(self)["is_min_heap"] @property - def is_perfect(self): + def is_perfect(self) -> bool: """Check if the binary tree is perfect. A binary tree is perfect if all its levels are completely filled. See @@ -1100,7 +1052,7 @@ def is_perfect(self): return _get_tree_properties(self)["is_perfect"] @property - def is_strict(self): + def is_strict(self) -> bool: """Check if the binary tree is strict. A binary tree is strict if all its non-leaf nodes have both the left @@ -1135,7 +1087,7 @@ def is_strict(self): return _get_tree_properties(self)["is_strict"] @property - def is_complete(self): + def is_complete(self) -> bool: """Check if the binary tree is complete. A binary tree is complete if it meets the following criteria: @@ -1172,11 +1124,11 @@ def is_complete(self): return _get_tree_properties(self)["is_complete"] @property - def min_node_value(self): + def min_node_value(self) -> NodeValue: """Return the minimum node value of the binary tree. :return: Minimum node value. - :rtype: int + :rtype: float | int **Example**: @@ -1194,11 +1146,11 @@ def min_node_value(self): return _get_tree_properties(self)["min_node_value"] @property - def max_node_value(self): + def max_node_value(self) -> NodeValue: """Return the maximum node value of the binary tree. :return: Maximum node value. - :rtype: int + :rtype: float | int **Example**: @@ -1216,7 +1168,7 @@ def max_node_value(self): return _get_tree_properties(self)["max_node_value"] @property - def max_leaf_depth(self): + def max_leaf_depth(self) -> int: """Return the maximum leaf node depth of the binary tree. :return: Maximum leaf node depth. @@ -1250,7 +1202,7 @@ def max_leaf_depth(self): return _get_tree_properties(self)["max_leaf_depth"] @property - def min_leaf_depth(self): + def min_leaf_depth(self) -> int: """Return the minimum leaf node depth of the binary tree. :return: Minimum leaf node depth. @@ -1284,7 +1236,7 @@ def min_leaf_depth(self): return _get_tree_properties(self)["min_leaf_depth"] @property - def properties(self): + def properties(self) -> Dict[str, Any]: """Return various properties of the binary tree. :return: Binary tree properties. @@ -1345,7 +1297,7 @@ def properties(self): return properties @property - def inorder(self): + def inorder(self) -> List["Node"]: """Return the nodes in the binary tree using in-order_ traversal. An in-order_ traversal visits left subtree, root, then right subtree. @@ -1380,7 +1332,7 @@ def inorder(self): """ result: List[Node] = [] stack: List[Node] = [] - node = self + node: Optional[Node] = self while node or stack: while node: @@ -1394,7 +1346,7 @@ def inorder(self): return result @property - def preorder(self): + def preorder(self) -> List["Node"]: """Return the nodes in the binary tree using pre-order_ traversal. A pre-order_ traversal visits root, left subtree, then right subtree. @@ -1427,8 +1379,8 @@ def preorder(self): >>> root.preorder [Node(1), Node(2), Node(4), Node(5), Node(3)] """ - result = [] - stack = [self] + result: List[Node] = [] + stack: List[Optional[Node]] = [self] while stack: node = stack.pop() @@ -1440,7 +1392,7 @@ def preorder(self): return result @property - def postorder(self): + def postorder(self) -> List["Node"]: """Return the nodes in the binary tree using post-order_ traversal. A post-order_ traversal visits left subtree, right subtree, then root. @@ -1473,8 +1425,8 @@ def postorder(self): >>> root.postorder [Node(4), Node(5), Node(2), Node(3), Node(1)] """ - result = [] - stack = [self] + result: List[Node] = [] + stack: List[Optional[Node]] = [self] while stack: node = stack.pop() @@ -1486,7 +1438,7 @@ def postorder(self): return result[::-1] @property - def levelorder(self): + def levelorder(self) -> List["Node"]: """Return the nodes in the binary tree using level-order_ traversal. A level-order_ traversal visits nodes left to right, level by level. @@ -1536,7 +1488,7 @@ def levelorder(self): return result -def _is_balanced(root: Optional[Node]): +def _is_balanced(root: Optional[Node]) -> int: """Return the tree height + 1 if balanced, -1 otherwise. :param root: Root node of the binary tree. @@ -1555,7 +1507,7 @@ def _is_balanced(root: Optional[Node]): return -1 if abs(left - right) > 1 else max(left, right) + 1 -def _is_bst(root: Node): +def _is_bst(root: Optional[Node]) -> bool: """Check if the binary tree is a BST (binary search tree). :param root: Root node of the binary tree. @@ -1564,8 +1516,9 @@ def _is_bst(root: Node): :rtype: bool """ stack: List[Node] = [] - cur: Optional[Node] = root - pre: Optional[Node] = None + cur = root + pre = None + while stack or cur is not None: if cur is not None: stack.append(cur) @@ -1579,7 +1532,7 @@ def _is_bst(root: Node): return True -def _is_symmetric(root): +def _is_symmetric(root: Optional[Node]) -> bool: """Check if the binary tree is symmetric. :param root: Root node of the binary tree. @@ -1602,37 +1555,37 @@ def symmetric_helper(left_subtree, right_subtree): return symmetric_helper(root, root) -def _validate_tree_height(height): +def _validate_tree_height(height: int): """Check if the height of the binary tree is valid. :param height: Height of the binary tree (must be 0 - 9 inclusive). :type height: int :raise binarytree.exceptions.TreeHeightError: If height is invalid. """ - if not (isinstance(height, int) and 0 <= height <= 9): + if not (type(height) == int and 0 <= height <= 9): raise TreeHeightError("height must be an int between 0 - 9") -def _generate_perfect_bst(height): +def _generate_perfect_bst(height: int) -> Optional[Node]: """Generate a perfect BST (binary search tree) and return its root. :param height: Height of the BST. :type height: int :return: Root node of the BST. - :rtype: binarytree.Node + :rtype: binarytree.Node | None """ max_node_count = 2 ** (height + 1) - 1 node_values = list(range(max_node_count)) return _build_bst_from_sorted_values(node_values) -def _build_bst_from_sorted_values(sorted_values): +def _build_bst_from_sorted_values(sorted_values: List[int]) -> Optional[Node]: """Recursively build a perfect BST from odd number of sorted values. :param sorted_values: Odd number of sorted values. :type sorted_values: [int | float] :return: Root node of the BST. - :rtype: binarytree.Node + :rtype: binarytree.Node | None """ if len(sorted_values) == 0: return None @@ -1643,7 +1596,7 @@ def _build_bst_from_sorted_values(sorted_values): return root -def _generate_random_leaf_count(height): +def _generate_random_leaf_count(height: int) -> int: """Return a random leaf count for building binary trees. :param height: Height of the binary tree. @@ -1660,7 +1613,7 @@ def _generate_random_leaf_count(height): return roll_1 + roll_2 or half_leaf_count -def _generate_random_node_values(height): +def _generate_random_node_values(height: int) -> List[int]: """Return random node values for building binary trees. :param height: Height of the binary tree. @@ -1674,7 +1627,9 @@ def _generate_random_node_values(height): return node_values -def _build_tree_string(root, curr_index, index=False, delimiter="-"): +def _build_tree_string( + root: Optional[Node], curr_index: int, index: bool = False, delimiter: str = "-" +) -> Tuple[List[str], int, int, int]: """Recursively walk down the binary tree and build a pretty-print string. In each recursive call, a "box" of characters visually representing the @@ -1762,7 +1717,7 @@ def _build_tree_string(root, curr_index, index=False, delimiter="-"): return new_box, len(new_box[0]), new_root_start, new_root_end -def _get_tree_properties(root): +def _get_tree_properties(root: Node) -> Dict[str, Any]: """Inspect the binary tree and return its properties (e.g. height). :param root: Root node of the binary tree. @@ -1840,7 +1795,59 @@ def _get_tree_properties(root): } -def build(values): +def get_parent(root: Node, child: Node) -> Optional[Node]: + """Search the binary tree and return the parent of given child. + + :param root: Root node of the binary tree. + :type: binarytree.Node + :param child: Child node. + :rtype: binarytree.Node + :return: Parent node, or None if missing. + :rtype: binarytree.Node | None + + **Example**: + + .. doctest:: + + >>> from binarytree import Node, get_parent + >>> + >>> root = Node(1) + >>> root.left = Node(2) + >>> root.right = Node(3) + >>> root.left.right = Node(4) + >>> + >>> print(root) + + __1 + / \\ + 2 3 + \\ + 4 + + >>> print(get_parent(root, root.left.right)) + + 2 + \\ + 4 + + """ + if child is None: + return None + + stack: List[Optional[Node]] = [root] + + while stack: + node = stack.pop() + if node: + if node.left is child or node.right is child: + return node + else: + stack.append(node.left) + stack.append(node.right) + return None + + +def build(values: List) -> Optional[Node]: """Build a tree from `list representation`_ and return its root node. .. _list representation: @@ -1853,7 +1860,7 @@ def build(values): absence of a node at that index. See example below for an illustration. :type values: [int | float | None] :return: Root node of the binary tree. - :rtype: binarytree.Node + :rtype: binarytree.Node | None :raise binarytree.exceptions.NodeNotFoundError: If the list representation is malformed (e.g. a parent node is missing). @@ -1894,12 +1901,12 @@ def build(values): raise NodeNotFoundError( "parent node missing at index {}".format(parent_index) ) - setattr(parent, LEFT if index % 2 else RIGHT, node) + setattr(parent, LEFT_FIELD if index % 2 else RIGHT_FIELD, node) return nodes[0] if nodes else None -def tree(height=3, is_perfect=False): +def tree(height: int = 3, is_perfect: bool = False) -> Optional[Node]: """Generate a random binary tree and return its root node. :param height: Height of the tree (default: 3, range: 0 - 9 inclusive). @@ -1958,7 +1965,7 @@ def tree(height=3, is_perfect=False): inserted = False while depth < height and not inserted: - attr = random.choice((LEFT, RIGHT)) + attr = random.choice((LEFT_FIELD, RIGHT_FIELD)) if getattr(node, attr) is None: setattr(node, attr, Node(value)) inserted = True @@ -1973,7 +1980,7 @@ def tree(height=3, is_perfect=False): return root_node -def bst(height=3, is_perfect=False): +def bst(height: int = 3, is_perfect: bool = False) -> Optional[Node]: """Generate a random BST (binary search tree) and return its root node. :param height: Height of the BST (default: 3, range: 0 - 9 inclusive). @@ -2024,7 +2031,7 @@ def bst(height=3, is_perfect=False): inserted = False while depth < height and not inserted: - attr = LEFT if node.val > value else RIGHT + attr = LEFT_FIELD if node.val > value else RIGHT_FIELD if getattr(node, attr) is None: setattr(node, attr, Node(value)) inserted = True @@ -2039,7 +2046,9 @@ def bst(height=3, is_perfect=False): return root_node -def heap(height=3, is_max=True, is_perfect=False): +def heap( + height: int = 3, is_max: bool = True, is_perfect: bool = False +) -> Optional[Node]: """Generate a random heap and return its root node. :param height: Height of the heap (default: 3, range: 0 - 9 inclusive).