Skip to content

Commit

Permalink
fix tree height checks, make NodeWalker track heights properly
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Dec 14, 2024
1 parent f2c0fd7 commit 613fb3c
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 19 deletions.
21 changes: 14 additions & 7 deletions src/atmst/mst/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def deserialise(cls, data: bytes) -> "Self":
keys=tuple(keys),
vals=tuple(vals)
)

def is_empty(self) -> bool:
return self.subtrees == (None,)

Expand All @@ -128,18 +128,25 @@ def _to_optional(self) -> Optional[CID]:


@cached_property
def height(self) -> int:
def maybe_height(self) -> Optional[int]:
# if there are keys at this level, query one directly
if self.keys:
return self.key_height(self.keys[0])

# we're an empty tree
if self.subtrees[0] is None:
return 0

# this should only happen for non-root nodes with no keys
raise Exception("cannot determine node height")


# this should only happen for non-root nodes with no keys (aka an empty intermediate node)
return None
# NOTE: a Node class cannot see what's below it. You'll need to track
# state externally (like NodeWrangler does) if you want to find out

def definitely_height(self) -> int:
if self.maybe_height is None:
raise ValueError("indeterminate node height")
return self.maybe_height

def gte_index(self, key: str) -> int:
"""
find the index of the first key greater than or equal to the specified key
Expand Down
2 changes: 1 addition & 1 deletion src/atmst/mst/node_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def pretty(self, node_cid: Optional[CID]) -> str:
if node_cid is None:
return "<empty>"
node = self.get_node(node_cid)
res = f"MSTNode<cid={node.cid.encode('base32')}>(\n{indent(self.pretty(node.subtrees[0]))},\n"
res = f"MSTNode<cid={node.cid.encode('base32')}, maybe_height={node.maybe_height}>(\n{indent(self.pretty(node.subtrees[0]))},\n"
for k, v, t in zip(node.keys, node.vals, node.subtrees[1:]):
res += f" {k!r} ({MSTNode.key_height(k)}) -> {v.encode('base32')},\n"
res += indent(self.pretty(t)) + ",\n"
Expand Down
34 changes: 27 additions & 7 deletions src/atmst/mst/node_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,46 @@ class StackFrame:

ns: NodeStore
stack: List[StackFrame]
root_height: int

def __init__(self, ns: NodeStore, root_cid: Optional[CID], lpath: Optional[str]=PATH_MIN, rpath: Optional[str]=PATH_MAX, trusted: Optional[bool]=False) -> None:
def __init__(self,
ns: NodeStore,
root_cid: Optional[CID],
lpath: str=PATH_MIN,
rpath: str=PATH_MAX,
trusted: bool=False,
root_height: Optional[int]=None
) -> None:
self.ns = ns
self.trusted = trusted
node = MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid)
self.root_height = node.maybe_height if root_height is None else root_height
if self.root_height is None:
raise ValueError("indeterminate node height - pass it in if you know it")
self.stack = [self.StackFrame(
node=MSTNode.empty_root() if root_cid is None else self.ns.get_node(root_cid),
node=node,
lpath=lpath,
rpath=rpath,
idx=0
)]

def subtree_walker(self) -> "Self":
return NodeWalker(self.ns, self.subtree, self.lpath, self.rpath, self.trusted)
return NodeWalker(
self.ns,
self.subtree,
self.lpath,
self.rpath,
self.trusted,
root_height=self.height - 1
)

@property
def frame(self) -> StackFrame:
return self.stack[-1]

@property
def height(self) -> int:
return self.frame.node.height
return self.root_height - (len(self.stack) - 1)

@property
def lpath(self) -> str:
Expand Down Expand Up @@ -91,7 +110,7 @@ def can_go_right(self) -> bool:
def right_or_up(self) -> None:
if not self.can_go_right:
# we reached the end of this node, go up a level
self.stack.pop()
self.stack.pop() # TODO: check before pop - make empty-stack an unreachable state
if not self.stack:
raise StopIteration # you probably want to check .final instead of hitting this
return self.right_or_up() # we need to recurse, to skip over empty intermediates on the way back up
Expand All @@ -110,8 +129,9 @@ def down(self) -> None:
subtree_node = self.ns.get_node(subtree)

if not self.trusted: # if we "trust" the source we can elide this check
if subtree_node.height != self.height - 1:
raise ValueError("inconsistent subtree height")
# the "None" case occurs for empty intermediate nodes
if subtree_node.maybe_height is not None and subtree_node.maybe_height != self.height - 1:
raise ValueError(f"inconsistent subtree height ({subtree_node.maybe_height}, expected {self.height - 1})")

self.stack.append(self.StackFrame(
node=subtree_node,
Expand Down
8 changes: 6 additions & 2 deletions src/atmst/mst/node_wrangler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ def put_record(self, root_cid: CID, key: str, val: CID) -> CID:
root = self.ns.get_node(root_cid)
if root.is_empty(): # special case for empty tree
return self._put_here(root, key, val).cid
return self._put_recursive(root, key, val, MSTNode.key_height(key), root.height).cid
return self._put_recursive(root, key, val, MSTNode.key_height(key), root.definitely_height()).cid

def del_record(self, root_cid: CID, key: str) -> CID:
root = self.ns.get_node(root_cid)

# Note: the seemingly redundant outer .get().cid is required to transform
# a None cid into the cid representing an empty node (we could maybe find a more elegant
# way of doing this...)
return self.ns.get_node(self._squash_top(self._delete_recursive(root, key, MSTNode.key_height(key), root.height))).cid
return self.ns.get_node(self._squash_top(
self._delete_recursive(
root, key, MSTNode.key_height(key), root.definitely_height()
)
)).cid



Expand Down
7 changes: 5 additions & 2 deletions tests/test_mst_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ def setUp(self):
self.trees.append(root)

def test_diff_all_pairs(self):
for a in self.trees:
for b in self.trees:
for ai, a in enumerate(self.trees):
for bi, b in enumerate(self.trees):
#print(ai, bi)
#print(self.ns.pretty(a))
#print(self.ns.pretty(b))
reference_created, reference_deleted = very_slow_mst_diff(self.ns, a, b)
created, deleted = mst_diff(self.ns, a, b)
self.assertEqual(created, reference_created)
Expand Down

0 comments on commit 613fb3c

Please sign in to comment.