diff --git a/src/atmst/mst/node.py b/src/atmst/mst/node.py index 1691f02..fba9dae 100644 --- a/src/atmst/mst/node.py +++ b/src/atmst/mst/node.py @@ -114,7 +114,7 @@ def deserialise(cls, data: bytes) -> "Self": keys=tuple(keys), vals=tuple(vals) ) - + def is_empty(self) -> bool: return self.subtrees == (None,) @@ -128,18 +128,25 @@ def _to_optional(self) -> Optional[CID]: @cached_property - def height(self) -> int: + def maybe_height(self) -> Optional[int]: # if there are keys at this level, query one directly if self.keys: return self.key_height(self.keys[0]) - + # we're an empty tree if self.subtrees[0] is None: return 0 - - # this should only happen for non-root nodes with no keys - raise Exception("cannot determine node height") - + + # this should only happen for non-root nodes with no keys (aka an empty intermediate node) + return None + # NOTE: a Node class cannot see what's below it. You'll need to track + # state externally (like NodeWrangler does) if you want to find out + + def definitely_height(self) -> int: + if self.maybe_height is None: + raise ValueError("indeterminate node height") + return self.maybe_height + def gte_index(self, key: str) -> int: """ find the index of the first key greater than or equal to the specified key diff --git a/src/atmst/mst/node_store.py b/src/atmst/mst/node_store.py index 2ec5d73..ebdbead 100644 --- a/src/atmst/mst/node_store.py +++ b/src/atmst/mst/node_store.py @@ -52,7 +52,7 @@ def pretty(self, node_cid: Optional[CID]) -> str: if node_cid is None: return "" node = self.get_node(node_cid) - res = f"MSTNode(\n{indent(self.pretty(node.subtrees[0]))},\n" + res = f"MSTNode(\n{indent(self.pretty(node.subtrees[0]))},\n" for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]): res += f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode('base32')},\n" res += indent(self.pretty(t)) + ",\n" diff --git a/src/atmst/mst/node_walker.py b/src/atmst/mst/node_walker.py index 4c9dcb8..ec05f9d 100644 --- a/src/atmst/mst/node_walker.py +++ b/src/atmst/mst/node_walker.py @@ -37,19 +37,38 @@ class StackFrame: ns: NodeStore stack: List[StackFrame] + root_height: int - def __init__(self, ns: NodeStore, root_cid: Optional[CID], lpath: Optional[str]=PATH_MIN, rpath: Optional[str]=PATH_MAX, trusted: Optional[bool]=False) -> None: + def __init__(self, + ns: NodeStore, + root_cid: Optional[CID], + lpath: str=PATH_MIN, + rpath: str=PATH_MAX, + trusted: bool=False, + root_height: Optional[int]=None + ) -> None: self.ns = ns self.trusted = trusted + node = MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid) + self.root_height = node.maybe_height if root_height is None else root_height + if self.root_height is None: + raise ValueError("indeterminate node height - pass it in if you know it") self.stack = [self.StackFrame( - node=MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid), + node=node, lpath=lpath, rpath=rpath, idx=0 )] def subtree_walker(self) -> "Self": - return NodeWalker(self.ns, self.subtree, self.lpath, self.rpath, self.trusted) + return NodeWalker( + self.ns, + self.subtree, + self.lpath, + self.rpath, + self.trusted, + root_height=self.height - 1 + ) @property def frame(self) -> StackFrame: @@ -57,7 +76,7 @@ def frame(self) -> StackFrame: @property def height(self) -> int: - return self.frame.node.height + return self.root_height - (len(self.stack) - 1) @property def lpath(self) -> str: @@ -91,7 +110,7 @@ def can_go_right(self) -> bool: def right_or_up(self) -> None: if not self.can_go_right: # we reached the end of this node, go up a level - self.stack.pop() + self.stack.pop() # TODO: check before pop - make empty-stack an unreachable state if not self.stack: raise StopIteration # you probably want to check .final instead of hitting this return self.right_or_up() # we need to recurse, to skip over empty intermediates on the way back up @@ -110,8 +129,9 @@ def down(self) -> None: subtree_node = self.ns.get_node(subtree) if not self.trusted: # if we "trust" the source we can elide this check - if subtree_node.height != self.height - 1: - raise ValueError("inconsistent subtree height") + # the "None" case occurs for empty intermediate nodes + if subtree_node.maybe_height is not None and subtree_node.maybe_height != self.height - 1: + raise ValueError(f"inconsistent subtree height ({subtree_node.maybe_height}, expected {self.height - 1})") self.stack.append(self.StackFrame( node=subtree_node, diff --git a/src/atmst/mst/node_wrangler.py b/src/atmst/mst/node_wrangler.py index fcd354a..906abf1 100644 --- a/src/atmst/mst/node_wrangler.py +++ b/src/atmst/mst/node_wrangler.py @@ -37,7 +37,7 @@ def put_record(self, root_cid: CID, key: str, val: CID) -> CID: root = self.ns.get_node(root_cid) if root.is_empty(): # special case for empty tree return self._put_here(root, key, val).cid - return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid + return self._put_recursive(root, key, val, MSTNode.key_height(key), root.definitely_height()).cid def del_record(self, root_cid: CID, key: str) -> CID: root = self.ns.get_node(root_cid) @@ -45,7 +45,11 @@ def del_record(self, root_cid: CID, key: str) -> CID: # Note: the seemingly redundant outer .get().cid is required to transform # a None cid into the cid representing an empty node (we could maybe find a more elegant # way of doing this...) - return self.ns.get_node(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid + return self.ns.get_node(self._squash_top( + self._delete_recursive( + root, key, MSTNode.key_height(key), root.definitely_height() + ) + )).cid diff --git a/tests/test_mst_diff.py b/tests/test_mst_diff.py index 31f71bf..e5d8545 100644 --- a/tests/test_mst_diff.py +++ b/tests/test_mst_diff.py @@ -43,8 +43,11 @@ def setUp(self): self.trees.append(root) def test_diff_all_pairs(self): - for a in self.trees: - for b in self.trees: + for ai, a in enumerate(self.trees): + for bi, b in enumerate(self.trees): + #print(ai, bi) + #print(self.ns.pretty(a)) + #print(self.ns.pretty(b)) reference_created, reference_deleted = very_slow_mst_diff(self.ns, a, b) created, deleted = mst_diff(self.ns, a, b) self.assertEqual(created, reference_created)