Skip to content

Commit

Permalink
fix(imt.sol): update correct side nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
cedoor committed Dec 14, 2023
1 parent 039b58c commit de86aab
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 54 deletions.
63 changes: 26 additions & 37 deletions packages/imt.sol/contracts/internal/InternalLeanIMT.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pragma solidity ^0.8.4;
import {PoseidonT3} from "poseidon-solidity/PoseidonT3.sol";
import {SNARK_SCALAR_FIELD} from "../Constants.sol";

import "hardhat/console.sol";

struct LeanIMTData {
// Tracks the current number of leaves in the tree.
uint256 size;
Expand Down Expand Up @@ -92,61 +94,48 @@ library InternalLeanIMT {
revert LeafDoesNotExist();
} else if (newLeaf != 0 && _has(self, newLeaf)) {
revert LeafAlreadyExists();
} else if (siblingNodes.length == 0) {
revert WrongSiblingNodes();
}

uint256 index = _indexOf(self, oldLeaf);
uint256 node = newLeaf;
uint256 oldRoot = oldLeaf;

// A counter that adjusts the level at which sibling nodes are
// accessed and updated during the tree's update process.
// It ensures that the update function correctly navigates and
// modifies the tree's nodes at the appropriate levels, accounting
// for situations where not every level of the tree requires an
// update or a hash calculation.
uint256 s = 0;

// The number of siblings of a proof can be less than
// the depth of the tree, because in some levels it might not
// be necessary to hash any value.
for (uint256 i = 0; i < siblingNodes.length; ) {
if (siblingNodes[i] >= SNARK_SCALAR_FIELD) {
revert LeafGreaterThanSnarkScalarField();
}
uint256 lastIndex = self.size - 1;
uint256 i = 0;

uint256 level = i + s;
for (uint256 level = 0; level < self.depth; ) {
if ((index >> level) & 1 == 1) {
if (siblingNodes[i] >= SNARK_SCALAR_FIELD) {
revert LeafGreaterThanSnarkScalarField();
}

if (oldRoot == self.sideNodes[level]) {
self.sideNodes[level] = node;
node = PoseidonT3.hash([siblingNodes[i], node]);
oldRoot = PoseidonT3.hash([siblingNodes[i], oldRoot]);

if (oldRoot == self.sideNodes[level + 1]) {
s += 1;
unchecked {
++i;
}
} else {
if (index >> level != lastIndex >> level) {
if (siblingNodes[i] >= SNARK_SCALAR_FIELD) {
revert LeafGreaterThanSnarkScalarField();
}

uint256 j = 0;

while (oldRoot == self.sideNodes[level + j + 1]) {
self.sideNodes[level + j + 1] = node;
node = PoseidonT3.hash([node, siblingNodes[i]]);
oldRoot = PoseidonT3.hash([oldRoot, siblingNodes[i]]);

unchecked {
++s;
++j;
++i;
}
} else {
self.sideNodes[i] = node;
}

level = i + s;
}

if ((index >> level) & 1 != 0) {
node = PoseidonT3.hash([siblingNodes[i], node]);
oldRoot = PoseidonT3.hash([siblingNodes[i], oldRoot]);
} else {
node = PoseidonT3.hash([node, siblingNodes[i]]);
oldRoot = PoseidonT3.hash([oldRoot, siblingNodes[i]]);
}

unchecked {
++i;
++level;
}
}

Expand Down
52 changes: 35 additions & 17 deletions packages/imt.sol/test/LeanIMT.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,41 +77,57 @@ describe("LeanIMT", () => {
await expect(transaction).to.be.revertedWithCustomError(leanIMT, "LeafGreaterThanSnarkScalarField")
})

it("Should update a leaf", async () => {
it("Should not update a leaf if there are no sibling nodes", async () => {
await leanIMTTest.insert(1)

jsLeanIMT.insert(BigInt(1))
jsLeanIMT.update(0, BigInt(2))

const { siblings } = jsLeanIMT.generateProof(0)

await leanIMTTest.update(1, 2, siblings)
const transaction = leanIMTTest.update(1, 2, siblings)

await expect(transaction).to.be.revertedWithCustomError(leanIMT, "WrongSiblingNodes")
})

it("Should update a leaf", async () => {
await leanIMTTest.insert(1)
await leanIMTTest.insert(2)

jsLeanIMT.insertMany([BigInt(1), BigInt(2)])
jsLeanIMT.update(0, BigInt(3))

const { siblings } = jsLeanIMT.generateProof(0)

await leanIMTTest.update(1, 3, siblings)

const root = await leanIMTTest.root()

expect(root).to.equal(jsLeanIMT.root)
})

it("Should not update a leaf if the value of at least one leaf is > SNARK_SCALAR_FIELD", async () => {
it("Should not update a leaf if the value of at least one sibling node is > SNARK_SCALAR_FIELD", async () => {
await leanIMTTest.insert(1)
await leanIMTTest.insert(2)

jsLeanIMT.insert(BigInt(1))
jsLeanIMT.update(0, BigInt(2))
jsLeanIMT.insertMany([BigInt(1), BigInt(2)])
jsLeanIMT.update(0, BigInt(3))

const { siblings } = jsLeanIMT.generateProof(0)

siblings[0] = SNARK_SCALAR_FIELD

const transaction = leanIMTTest.update(1, 2, siblings)
const transaction = leanIMTTest.update(1, 3, siblings)

await expect(transaction).to.be.revertedWithCustomError(leanIMT, "LeafGreaterThanSnarkScalarField")
})

it("Should not update a leaf if the siblings are wrong", async () => {
await leanIMTTest.insert(1)
await leanIMTTest.insert(2)

jsLeanIMT.insert(BigInt(1))
jsLeanIMT.update(0, BigInt(2))
jsLeanIMT.insertMany([BigInt(1), BigInt(2)])
jsLeanIMT.update(0, BigInt(3))

const { siblings } = jsLeanIMT.generateProof(0)

Expand All @@ -122,19 +138,19 @@ describe("LeanIMT", () => {
await expect(transaction).to.be.revertedWithCustomError(leanIMT, "WrongSiblingNodes")
})

it("Should update 10 leaves", async () => {
for (let i = 0; i < 10; i += 1) {
it("Should update 6 leaves", async () => {
for (let i = 0; i < 6; i += 1) {
jsLeanIMT.insert(BigInt(i + 1))

await leanIMTTest.insert(i + 1)
}

for (let i = 0; i < 10; i += 1) {
jsLeanIMT.update(i, BigInt(i + 11))
for (let i = 0; i < 6; i += 1) {
jsLeanIMT.update(i, BigInt(i + 7))

const { siblings } = jsLeanIMT.generateProof(i)

await leanIMTTest.update(i + 1, i + 11, siblings)
await leanIMTTest.update(i + 1, i + 7, siblings)

const root = await leanIMTTest.root()

Expand All @@ -146,13 +162,15 @@ describe("LeanIMT", () => {
describe("# remove", () => {
it("Should remove a leaf", async () => {
await leanIMTTest.insert(1)
await leanIMTTest.insert(2)
await leanIMTTest.insert(3)

jsLeanIMT.insert(BigInt(1))
jsLeanIMT.update(0, BigInt(0))
jsLeanIMT.insertMany([BigInt(1), BigInt(2), BigInt(3)])
jsLeanIMT.update(2, BigInt(0))

const { siblings } = jsLeanIMT.generateProof(0)
const { siblings } = jsLeanIMT.generateProof(2)

await leanIMTTest.remove(1, siblings)
await leanIMTTest.remove(3, siblings)

const root = await leanIMTTest.root()

Expand Down

0 comments on commit de86aab

Please sign in to comment.