diff --git a/packages/imt.sol/contracts/LazyIMT.sol b/packages/imt.sol/contracts/LazyIMT.sol index 660612f3f..82dce538e 100644 --- a/packages/imt.sol/contracts/LazyIMT.sol +++ b/packages/imt.sol/contracts/LazyIMT.sol @@ -78,29 +78,24 @@ library LazyIMT { } function root(LazyIMTData storage self) public view returns (uint256) { - // this will always short circuit if self.numberOfLeaves == 0 uint40 numberOfLeaves = self.numberOfLeaves; // dynamically determine a depth uint8 depth = 1; - while (uint8(2)**depth < numberOfLeaves) { + while (uint8(2) ** depth < numberOfLeaves) { depth++; } return _root(self, numberOfLeaves, depth); } - function root(LazyTreeData storage self, uint8 depth) public view returns (uint256) { + function root(LazyIMTData storage self, uint8 depth) public view returns (uint256) { uint40 numberOfLeaves = self.numberOfLeaves; - require(2**depth >= numberOfLeaves, "LazyMerkleTree: ambiguous depth"); + require(2 ** depth >= numberOfLeaves, "LazyIMT: ambiguous depth"); return _root(self, self.numberOfLeaves, depth); } - function _root( - LazyTreeData storage self, - uint40 numberOfLeaves, - uint8 depth - ) internal view returns (uint256) { - require(depth > 0, "LazyMerkleTree: depth must be > 0"); - require(depth <= MAX_DEPTH, "LazyMerkleTree: depth must be < MAX_DEPTH"); + function _root(LazyIMTData storage self, uint40 numberOfLeaves, uint8 depth) internal view returns (uint256) { + require(depth > 0, "LazyIMT: depth must be > 0"); + require(depth <= MAX_DEPTH, "LazyIMT: depth must be < MAX_DEPTH"); // this should always short circuit if self.numberOfLeaves == 0 if (numberOfLeaves == 0) return defaultZero(depth); uint40 index = numberOfLeaves - 1; diff --git a/packages/imt.sol/contracts/test/LazyIMTTest.sol b/packages/imt.sol/contracts/test/LazyIMTTest.sol index ad454a1aa..b1fc4b998 100644 --- a/packages/imt.sol/contracts/test/LazyIMTTest.sol +++ b/packages/imt.sol/contracts/test/LazyIMTTest.sol @@ -6,6 +6,7 @@ import {LazyIMT, LazyIMTData} from "../LazyIMT.sol"; contract LazyIMTTest { LazyIMTData public data; + uint256 _root; function init(uint8 depth) public { LazyIMT.init(data, depth); @@ -23,7 +24,20 @@ contract LazyIMTTest { LazyIMT.update(data, leaf, index); } + // for benchmarking the root cost + function benchmarkRoot() public { + _root = LazyIMT.root(data); + } + function root() public view returns (uint256) { return LazyIMT.root(data); } + + function dynamicRoot(uint8 depth) public view returns (uint256) { + return LazyIMT.root(data, depth); + } + + function staticRoot(uint8 depth) public view returns (uint256) { + return LazyIMT.root(data, depth); + } } diff --git a/packages/imt.sol/test/LazyIMT.ts b/packages/imt.sol/test/LazyIMT.ts index 4fb108d50..147b43ea4 100644 --- a/packages/imt.sol/test/LazyIMT.ts +++ b/packages/imt.sol/test/LazyIMT.ts @@ -79,22 +79,36 @@ describe("LazyIMT", () => { elements.push(e) // construct the tree - const targetDepth = Math.max(1, Math.ceil(Math.log2(elements.length))) - const merkleTree = new IMT(poseidon2, targetDepth, BigInt(0)) - - for (const _e of elements) { - merkleTree.insert(_e) + { + const targetDepth = Math.max(1, Math.ceil(Math.log2(elements.length))) + const merkleTree = new IMT(poseidon2, targetDepth, BigInt(0)) + for (const _e of elements) { + merkleTree.insert(_e) + } + await lazyIMTTest.insert(e) + await lazyIMTTest.benchmarkRoot().then((t) => t.wait()) + { + const root = await lazyIMTTest.root() + expect(root.toString()).to.equal(merkleTree.root.toString()) + } + { + const root = await lazyIMTTest.dynamicRoot(targetDepth) + expect(root.toString()).to.equal(merkleTree.root.toString()) + } } - await lazyIMTTest.insert(e) - - const root = await lazyIMTTest.root() - - expect(root.toString()).to.equal(merkleTree.root.toString()) - const treeData = await lazyIMTTest.data() expect(treeData.numberOfLeaves).to.equal(elements.length) + + for (let y = depth; y < 12; y += 1) { + const merkleTree = new IMT(poseidon2, y, BigInt(0)) + for (const _e of elements) { + merkleTree.insert(_e) + } + const root = await lazyIMTTest.staticRoot(y) + expect(root.toString()).to.equal(merkleTree.root.toString()) + } } }) } @@ -147,6 +161,8 @@ describe("LazyIMT", () => { await lazyIMTTest.insert(e) + await lazyIMTTest.benchmarkRoot().then((t) => t.wait()) + const root = await lazyIMTTest.root() expect(root.toString()).to.equal(merkleTree.root.toString()) @@ -256,4 +272,17 @@ describe("LazyIMT", () => { } }) }) + + it("Should fail to generate out of range static root", async () => { + await lazyIMTTest.init(10) + + const elements = [] + for (let x = 0; x < 20; x += 1) { + const e = random() + elements.push(e) + await lazyIMTTest.insert(e) + } + await expect(lazyIMTTest.staticRoot(4)).to.be.revertedWith("LazyIMT: ambiguous depth") + await expect(lazyIMTTest.staticRoot(33)).to.be.revertedWith("LazyIMT: depth must be < MAX_DEPTH") + }) })