From 8138bde87493a0959315cda2cfd81cb1e637cfe8 Mon Sep 17 00:00:00 2001 From: David Buchanan Date: Mon, 21 Oct 2024 01:18:37 +0100 Subject: [PATCH] fix _put_recursive tree grow case --- src/atmst/mst/node_wrangler.py | 6 +++++- tests/test_mst_diff.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/atmst/mst/node_wrangler.py b/src/atmst/mst/node_wrangler.py index 2a13031..fcd354a 100644 --- a/src/atmst/mst/node_wrangler.py +++ b/src/atmst/mst/node_wrangler.py @@ -73,7 +73,11 @@ def _put_here(self, node: MSTNode, key: str, val: CID) -> MSTNode: def _put_recursive(self, node: MSTNode, key: str, val: CID, key_height: int, tree_height: int) -> MSTNode: if key_height > tree_height: # we need to grow the tree return self.ns.stored_node(self._put_recursive( - MSTNode.empty_root(), + self.ns.stored_node(MSTNode( + keys=(), + vals=(), + subtrees=(node.cid,) + )), key, val, key_height, tree_height + 1 )) diff --git a/tests/test_mst_diff.py b/tests/test_mst_diff.py index 83658e8..7b5a58d 100644 --- a/tests/test_mst_diff.py +++ b/tests/test_mst_diff.py @@ -5,6 +5,13 @@ from atmst.mst.node import MSTNode from cbrrr import CID +def dump_mst(ns: NodeStore, cid: CID, lvl=0): + node = ns.get_node(cid) + print(" "*lvl + "-", node) + for subtree in node.subtrees: + if subtree: + dump_mst(ns, subtree, lvl+1) + class MSTDiffTestCase(unittest.TestCase): def setUp(self): keys = [] @@ -60,6 +67,12 @@ def test_insertion_order_independent(self): for k in keys: mst_c = wrangler.put_record(mst_c, k, CID.cidv1_dag_cbor_sha256_32_from(k.encode())) + #print() + #dump_mst(self.ns, mst_a) + + #print() + #dump_mst(self.ns, mst_b) + self.assertEqual(mst_a, mst_b) self.assertEqual(mst_a, mst_c)