From de86aabea5f049707b382bc245dbd67c0618a576 Mon Sep 17 00:00:00 2001 From: cedoor Date: Thu, 14 Dec 2023 16:03:53 +0100 Subject: [PATCH] fix(imt.sol): update correct side nodes --- .../contracts/internal/InternalLeanIMT.sol | 63 ++++++++----------- packages/imt.sol/test/LeanIMT.ts | 52 ++++++++++----- 2 files changed, 61 insertions(+), 54 deletions(-) diff --git a/packages/imt.sol/contracts/internal/InternalLeanIMT.sol b/packages/imt.sol/contracts/internal/InternalLeanIMT.sol index 063e7e685..1524c975c 100644 --- a/packages/imt.sol/contracts/internal/InternalLeanIMT.sol +++ b/packages/imt.sol/contracts/internal/InternalLeanIMT.sol @@ -4,6 +4,8 @@ pragma solidity ^0.8.4; import {PoseidonT3} from "poseidon-solidity/PoseidonT3.sol"; import {SNARK_SCALAR_FIELD} from "../Constants.sol"; +import "hardhat/console.sol"; + struct LeanIMTData { // Tracks the current number of leaves in the tree. uint256 size; @@ -92,61 +94,48 @@ library InternalLeanIMT { revert LeafDoesNotExist(); } else if (newLeaf != 0 && _has(self, newLeaf)) { revert LeafAlreadyExists(); + } else if (siblingNodes.length == 0) { + revert WrongSiblingNodes(); } uint256 index = _indexOf(self, oldLeaf); uint256 node = newLeaf; uint256 oldRoot = oldLeaf; - // A counter that adjusts the level at which sibling nodes are - // accessed and updated during the tree's update process. - // It ensures that the update function correctly navigates and - // modifies the tree's nodes at the appropriate levels, accounting - // for situations where not every level of the tree requires an - // update or a hash calculation. - uint256 s = 0; - - // The number of siblings of a proof can be less than - // the depth of the tree, because in some levels it might not - // be necessary to hash any value. - for (uint256 i = 0; i < siblingNodes.length; ) { - if (siblingNodes[i] >= SNARK_SCALAR_FIELD) { - revert LeafGreaterThanSnarkScalarField(); - } + uint256 lastIndex = self.size - 1; + uint256 i = 0; - uint256 level = i + s; + for (uint256 level = 0; level < self.depth; ) { + if ((index >> level) & 1 == 1) { + if (siblingNodes[i] >= SNARK_SCALAR_FIELD) { + revert LeafGreaterThanSnarkScalarField(); + } - if (oldRoot == self.sideNodes[level]) { - self.sideNodes[level] = node; + node = PoseidonT3.hash([siblingNodes[i], node]); + oldRoot = PoseidonT3.hash([siblingNodes[i], oldRoot]); - if (oldRoot == self.sideNodes[level + 1]) { - s += 1; + unchecked { + ++i; } + } else { + if (index >> level != lastIndex >> level) { + if (siblingNodes[i] >= SNARK_SCALAR_FIELD) { + revert LeafGreaterThanSnarkScalarField(); + } - uint256 j = 0; - - while (oldRoot == self.sideNodes[level + j + 1]) { - self.sideNodes[level + j + 1] = node; + node = PoseidonT3.hash([node, siblingNodes[i]]); + oldRoot = PoseidonT3.hash([oldRoot, siblingNodes[i]]); unchecked { - ++s; - ++j; + ++i; } + } else { + self.sideNodes[i] = node; } - - level = i + s; - } - - if ((index >> level) & 1 != 0) { - node = PoseidonT3.hash([siblingNodes[i], node]); - oldRoot = PoseidonT3.hash([siblingNodes[i], oldRoot]); - } else { - node = PoseidonT3.hash([node, siblingNodes[i]]); - oldRoot = PoseidonT3.hash([oldRoot, siblingNodes[i]]); } unchecked { - ++i; + ++level; } } diff --git a/packages/imt.sol/test/LeanIMT.ts b/packages/imt.sol/test/LeanIMT.ts index 80c49370d..ca30bb24b 100644 --- a/packages/imt.sol/test/LeanIMT.ts +++ b/packages/imt.sol/test/LeanIMT.ts @@ -77,7 +77,7 @@ describe("LeanIMT", () => { await expect(transaction).to.be.revertedWithCustomError(leanIMT, "LeafGreaterThanSnarkScalarField") }) - it("Should update a leaf", async () => { + it("Should not update a leaf if there are no sibling nodes", async () => { await leanIMTTest.insert(1) jsLeanIMT.insert(BigInt(1)) @@ -85,33 +85,49 @@ describe("LeanIMT", () => { const { siblings } = jsLeanIMT.generateProof(0) - await leanIMTTest.update(1, 2, siblings) + const transaction = leanIMTTest.update(1, 2, siblings) + + await expect(transaction).to.be.revertedWithCustomError(leanIMT, "WrongSiblingNodes") + }) + + it("Should update a leaf", async () => { + await leanIMTTest.insert(1) + await leanIMTTest.insert(2) + + jsLeanIMT.insertMany([BigInt(1), BigInt(2)]) + jsLeanIMT.update(0, BigInt(3)) + + const { siblings } = jsLeanIMT.generateProof(0) + + await leanIMTTest.update(1, 3, siblings) const root = await leanIMTTest.root() expect(root).to.equal(jsLeanIMT.root) }) - it("Should not update a leaf if the value of at least one leaf is > SNARK_SCALAR_FIELD", async () => { + it("Should not update a leaf if the value of at least one sibling node is > SNARK_SCALAR_FIELD", async () => { await leanIMTTest.insert(1) + await leanIMTTest.insert(2) - jsLeanIMT.insert(BigInt(1)) - jsLeanIMT.update(0, BigInt(2)) + jsLeanIMT.insertMany([BigInt(1), BigInt(2)]) + jsLeanIMT.update(0, BigInt(3)) const { siblings } = jsLeanIMT.generateProof(0) siblings[0] = SNARK_SCALAR_FIELD - const transaction = leanIMTTest.update(1, 2, siblings) + const transaction = leanIMTTest.update(1, 3, siblings) await expect(transaction).to.be.revertedWithCustomError(leanIMT, "LeafGreaterThanSnarkScalarField") }) it("Should not update a leaf if the siblings are wrong", async () => { await leanIMTTest.insert(1) + await leanIMTTest.insert(2) - jsLeanIMT.insert(BigInt(1)) - jsLeanIMT.update(0, BigInt(2)) + jsLeanIMT.insertMany([BigInt(1), BigInt(2)]) + jsLeanIMT.update(0, BigInt(3)) const { siblings } = jsLeanIMT.generateProof(0) @@ -122,19 +138,19 @@ describe("LeanIMT", () => { await expect(transaction).to.be.revertedWithCustomError(leanIMT, "WrongSiblingNodes") }) - it("Should update 10 leaves", async () => { - for (let i = 0; i < 10; i += 1) { + it("Should update 6 leaves", async () => { + for (let i = 0; i < 6; i += 1) { jsLeanIMT.insert(BigInt(i + 1)) await leanIMTTest.insert(i + 1) } - for (let i = 0; i < 10; i += 1) { - jsLeanIMT.update(i, BigInt(i + 11)) + for (let i = 0; i < 6; i += 1) { + jsLeanIMT.update(i, BigInt(i + 7)) const { siblings } = jsLeanIMT.generateProof(i) - await leanIMTTest.update(i + 1, i + 11, siblings) + await leanIMTTest.update(i + 1, i + 7, siblings) const root = await leanIMTTest.root() @@ -146,13 +162,15 @@ describe("LeanIMT", () => { describe("# remove", () => { it("Should remove a leaf", async () => { await leanIMTTest.insert(1) + await leanIMTTest.insert(2) + await leanIMTTest.insert(3) - jsLeanIMT.insert(BigInt(1)) - jsLeanIMT.update(0, BigInt(0)) + jsLeanIMT.insertMany([BigInt(1), BigInt(2), BigInt(3)]) + jsLeanIMT.update(2, BigInt(0)) - const { siblings } = jsLeanIMT.generateProof(0) + const { siblings } = jsLeanIMT.generateProof(2) - await leanIMTTest.remove(1, siblings) + await leanIMTTest.remove(3, siblings) const root = await leanIMTTest.root()