From 9c6b839284e5d980ca8d65197c037649b7d58bba Mon Sep 17 00:00:00 2001 From: KimiWu Date: Wed, 18 Oct 2023 16:18:12 +0800 Subject: [PATCH] feat: impl. sig_circuit --- specs/sig-proof.md | 2 +- specs/tables.md | 2 +- src/zkevm_specs/sig_circuit.py | 120 +++++++++++++++++++++++++ src/zkevm_specs/util/__init__.py | 2 + src/zkevm_specs/util/ec.py | 113 ++++++++++++++++++++++++ src/zkevm_specs/util/tables.py | 33 +++++++ tests/test_sig_circuit.py | 146 +++++++++++++++++++++++++++++++ 7 files changed, 416 insertions(+), 2 deletions(-) create mode 100644 src/zkevm_specs/sig_circuit.py create mode 100644 src/zkevm_specs/util/ec.py create mode 100644 src/zkevm_specs/util/tables.py create mode 100644 tests/test_sig_circuit.py diff --git a/specs/sig-proof.md b/specs/sig-proof.md index 5e1b7beb1..dc5cf8f05 100644 --- a/specs/sig-proof.md +++ b/specs/sig-proof.md @@ -23,7 +23,7 @@ Constraints on the shape of the table is like: | 0 msg_hash | 1 sig_v | 2 sig_r | 3 sig_s | 4 recovered_addr | 5 is_valid | | ------------- | ------ | ------------- | ------------- | ---------------- | ---------- | -| $value{Lo,Hi} | bool | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | +| $value{Lo,Hi} | 0/1 | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | The Sig Circuit aims at proving the correctness of SigTable. This mainly includes the following type of constraints: diff --git a/specs/tables.md b/specs/tables.md index a7f6334e6..677d68a8f 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -375,7 +375,7 @@ The circuit verifies the correctness of signatures. | 0 msg_hash | 1 sig_v | 2 sig_r | 3 sig_s | 4 recovered_addr | 5 is_valid | | ------------- | ------ | ------------- | ------------- | ---------------- | ---------- | -| $value{Lo,Hi} | bool | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | +| $value{Lo,Hi} | 0/1 | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | NOTE: - `sig_v` is either 0 or 1 so boolean type is used here. diff --git a/src/zkevm_specs/sig_circuit.py b/src/zkevm_specs/sig_circuit.py new file mode 100644 index 000000000..73e217f22 --- /dev/null +++ b/src/zkevm_specs/sig_circuit.py @@ -0,0 +1,120 @@ +from typing import List, NamedTuple +from .util import FQ, RLC, Word, linear_combine_bytes, ECDSAVerifyChip, KeccakTable, is_circuit_code +from eth_keys import KeyAPI # type: ignore +from eth_utils import keccak + + +class Row: + """ + Signature circuit + Verify a message hash is signed by an Ethereum Address. + """ + + msg_hash: Word + sig_v: FQ + sig_r: Word + sig_s: Word + recovered_addr: FQ + is_valid: FQ + + ecdsa_chip: ECDSAVerifyChip + pub_key_hash: bytes + pub_key_x_bytes: bytes + pub_key_y_bytes: bytes + msg_hash_bytes: bytes + + def __init__( + self, + pub_key_hash: bytes, + address: FQ, + msg_hash: Word, + ecdsa_chip: ECDSAVerifyChip, + is_valid: bool = True, + ) -> None: + self.ecdsa_chip = ecdsa_chip + self.pub_key_x_bytes = ecdsa_chip.pub_key_x_bytes + self.pub_key_y_bytes = ecdsa_chip.pub_key_y_bytes + self.msg_hash_bytes = ecdsa_chip.msg_hash_bytes + + # table + self.msg_hash = msg_hash + self.sig_v = FQ(int.from_bytes(self.ecdsa_chip.sig_v.le_bytes, "little")) + self.sig_r = Word(int.from_bytes(self.ecdsa_chip.sig_r.le_bytes, "little")) + self.sig_s = Word(int.from_bytes(self.ecdsa_chip.sig_s.le_bytes, "little")) + self.recovered_addr = address + self.is_valid = is_valid + + self.pub_key_hash = pub_key_hash + + @classmethod + def assign( + cls, + signature: KeyAPI.Signature, + pub_key: KeyAPI.PublicKey, + msg_hash: bytes, + is_valid: bool = True, + ): + pub_key_hash = keccak(pub_key.to_bytes()) + self_pub_key_hash = pub_key_hash + self_address = FQ(int.from_bytes(pub_key_hash[-20:], "big")) + self_msg_hash = Word(int.from_bytes(msg_hash, "big")) + self_ecdsa_chip = ECDSAVerifyChip.assign(signature, pub_key, msg_hash) + return cls(self_pub_key_hash, self_address, self_msg_hash, self_ecdsa_chip, is_valid) + + def verify(self, keccak_table: KeccakTable, keccak_randomness: FQ, assert_msg: str): + # 0. Copy constraints between pub_key, msg_hash and signature of this chip + # and the ones in ECDSA chip + assert self.pub_key_x_bytes == self.ecdsa_chip.pub_key_x_bytes + assert self.pub_key_y_bytes == self.ecdsa_chip.pub_key_y_bytes + assert self.msg_hash_bytes == self.ecdsa_chip.msg_hash_bytes + assert self.sig_r.int_value() == int.from_bytes(self.ecdsa_chip.sig_r.le_bytes, "little") + assert self.sig_s.int_value() == int.from_bytes(self.ecdsa_chip.sig_s.le_bytes, "little") + + # 1. Constrain v to be equal 0 or 1 + assert self.sig_v == 0 or self.sig_v == 1 + + # 2. Verify that keccak(pub_key_bytes) = pub_key_hash by keccak table + # lookup, where pub_key_bytes is built from the pub_key in the + # ecdsa_chip + pub_key_bytes = self.pub_key_x_bytes + self.pub_key_y_bytes + keccak_table.lookup( + True, + RLC(pub_key_bytes, keccak_randomness, n_bytes=64).expr(), + FQ(64), + Word(self.pub_key_hash), + assert_msg, + ) + + # 3. Verify that the first 20 bytes of the pub_key_hash equals `recovered_addr` + addr_expr = linear_combine_bytes(list(reversed(self.pub_key_hash[-20:])), FQ(2**8)) + assert ( + addr_expr == self.recovered_addr + ), f"{assert_msg}: {hex(addr_expr.n)} != {hex(self.recovered_addr.n)}" + + # 4. Verify that the signed message in the ecdsa_chip equals `msg_hash` + msg_hash = Word(self.msg_hash_bytes) + assert ( + msg_hash == self.msg_hash + ), f"{assert_msg}: {hex(msg_hash.int_value())} != {hex(self.msg_hash.int_value())}" + + # 5. Verify the ECDSA signature + is_valid = self.ecdsa_chip.verify() + assert is_valid == self.is_valid, f"{assert_msg}: {is_valid} != {self.is_valid}" + + +class Witness(NamedTuple): + rows: List[Row] # Transaction table rows + keccak_table: KeccakTable + + +@is_circuit_code +def verify_circuit( + witness: Witness, + keccak_randomness: FQ, +) -> None: + """ + Entry level circuit verification function + """ + for i, row in enumerate(witness.rows): + assert_msg = f"Constraints failed at row = {i}" + row.verify(witness.keccak_table, keccak_randomness, assert_msg) diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 8f4940a05..c361f55bc 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -3,3 +3,5 @@ from .hash import * from .param import * from .typing import * +from .ec import * +from .tables import * diff --git a/src/zkevm_specs/util/ec.py b/src/zkevm_specs/util/ec.py new file mode 100644 index 000000000..9c520badd --- /dev/null +++ b/src/zkevm_specs/util/ec.py @@ -0,0 +1,113 @@ +from typing import Tuple +from .arithmetic import FQ +from eth_keys import KeyAPI # type: ignore + + +class WrongFieldInteger: + """ + Wrong Field arithmetic Integer, representing the implementation at + https://github.com/privacy-scaling-explorations/halo2wrong/blob/master/integer/src/integer.rs + """ + + limbs: Tuple[FQ, FQ, FQ, FQ] # Little-Endian limbs of [72, 72, 72, 40] bits + le_bytes: bytes # Little-Endian bytes + + def __init__(self, value: int) -> None: + mask = (1 << 72) - 1 + l0 = (value >> 0 * 72) & mask + l1 = (value >> 1 * 72) & mask + l2 = (value >> 2 * 72) & mask + l3 = (value >> 3 * 72) & mask + self.limbs = (FQ(l0), FQ(l1), FQ(l2), FQ(l3)) + self.le_bytes = value.to_bytes(32, "little") + + def to_le_bytes(self) -> bytes: + (l0, l1, l2, l3) = self.limbs + val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) + return val.to_bytes(32, "little") + + def to_be_bytes(self) -> bytes: + (l0, l1, l2, l3) = self.limbs + val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) + return val.to_bytes(32, "big") + + +class Secp256k1BaseField(WrongFieldInteger): + """ + Secp256k1 Base Field. + """ + + def __init__(self, value: int) -> None: + WrongFieldInteger.__init__(self, value) + + +class Secp256k1ScalarField(WrongFieldInteger): + """ + Secp256k1 Scalar Field. + """ + + def __init__(self, value: int) -> None: + WrongFieldInteger.__init__(self, value) + + +# TODO: There is another one used in tx_circuit, try to merge into one. +# Reminder: endianness of public key is differ with the one in tx_circuit +class ECDSAVerifyChip: + """ + ECDSA Signature Verification Chip. This represents an ECDSA signature + verification Chip as implemented in + https://github.com/privacy-scaling-explorations/halo2wrong/blob/master/ecdsa/src/ecdsa.rs + """ + + sig_v: Secp256k1ScalarField + sig_r: Secp256k1ScalarField + sig_s: Secp256k1ScalarField + pub_key: Tuple[Secp256k1BaseField, Secp256k1BaseField] + pub_key_x_bytes: bytes + pub_key_y_bytes: bytes + msg_hash: Secp256k1ScalarField + msg_hash_bytes: bytes + + def __init__( + self, + signature: Tuple[Secp256k1ScalarField, Secp256k1ScalarField, Secp256k1ScalarField], + pub_key: Tuple[Secp256k1BaseField, Secp256k1BaseField], + msg_hash: Secp256k1ScalarField, + ) -> None: + self.sig_v = signature[0] + self.sig_r = signature[1] + self.sig_s = signature[2] + self.pub_key = pub_key + self.msg_hash = msg_hash + self.pub_key_x_bytes = pub_key[0].to_be_bytes() + self.pub_key_y_bytes = pub_key[1].to_be_bytes() + self.msg_hash_bytes = msg_hash.to_be_bytes() + # NOTE: The circuit must constrain that all elements in the `*_bytes` + # parameters are in range 0..255 and that they represent the same + # value as their corresponding WrongFieldInteger limbs. + + @classmethod + def assign(cls, signature: KeyAPI.Signature, pub_key: KeyAPI.PublicKey, msg_hash: bytes): + # signature + self_sig_v = Secp256k1ScalarField(signature.v) + self_sig_r = Secp256k1ScalarField(signature.r) + self_sig_s = Secp256k1ScalarField(signature.s) + # public key + pub_key_bytes = pub_key.to_bytes() + pub_key_bytes_x, pub_key_bytes_y = pub_key_bytes[:32], pub_key_bytes[32:] + pub_key_x = int.from_bytes(pub_key_bytes_x, "big") + pub_key_y = int.from_bytes(pub_key_bytes_y, "big") + self_pub_key = (Secp256k1BaseField(pub_key_x), Secp256k1BaseField(pub_key_y)) + # message hash + self_msg_hash = Secp256k1ScalarField(int.from_bytes(msg_hash, "big")) + return cls((self_sig_v, self_sig_r, self_sig_s), self_pub_key, self_msg_hash) + + def verify(self) -> bool: + sig_v = int.from_bytes(self.sig_v.to_le_bytes(), "little") + sig_r = int.from_bytes(self.sig_r.to_le_bytes(), "little") + sig_s = int.from_bytes(self.sig_s.to_le_bytes(), "little") + signature = KeyAPI.Signature(vrs=[sig_v, sig_r, sig_s]) + + msg_hash = bytes(self.msg_hash.to_be_bytes()) + public_key = KeyAPI.PublicKey(self.pub_key[0].to_be_bytes() + self.pub_key[1].to_be_bytes()) + return KeyAPI().ecdsa_verify(msg_hash, signature, public_key) diff --git a/src/zkevm_specs/util/tables.py b/src/zkevm_specs/util/tables.py new file mode 100644 index 000000000..161787bdb --- /dev/null +++ b/src/zkevm_specs/util/tables.py @@ -0,0 +1,33 @@ +from typing import Tuple, Set +from .arithmetic import ( + FQ, + RLC, + Word, +) +from eth_utils import keccak + + +class KeccakTable: + # The columns are: (is_enabled, input_rlc, input_len, output) + table: Set[Tuple[FQ, FQ, FQ, Word]] + + def __init__(self): + self.table = set() + self.table.add((FQ(0), FQ(0), FQ(0), Word(0))) # Add all 0s row + + def add(self, input: bytes, keccak_randomness: FQ): + output = keccak(input) + self.table.add( + ( + FQ(1), + RLC(input, keccak_randomness, n_bytes=64).expr(), + FQ(len(input)), + Word(output), + ) + ) + + def lookup(self, is_enabled: FQ, input_rlc: FQ, input_len: FQ, output: Word, assert_msg: str): + assert (is_enabled, input_rlc, input_len, output) in self.table, ( + f"{assert_msg}: {(is_enabled, input_rlc, input_len, output)} " + + "not found in the lookup table" + ) diff --git a/tests/test_sig_circuit.py b/tests/test_sig_circuit.py new file mode 100644 index 000000000..61d2c843f --- /dev/null +++ b/tests/test_sig_circuit.py @@ -0,0 +1,146 @@ +from typing import NamedTuple, List +from eth_keys import keys # type: ignore +from eth_utils import keccak +from zkevm_specs.sig_circuit import * +from zkevm_specs.util import FQ +from common import rand_fq +from zkevm_specs.util import ( + FQ, + Word, + U160, + U256, +) + +keccak_randomness = rand_fq() +r = keccak_randomness + + +class SignedData(NamedTuple): + msg_hash: bytes + sig_v: U256 + sig_r: U256 + sig_s: U256 + addr: U160 + is_valid: bool + + +def sign_msg(sk: keys.PrivateKey, msg: bytes, valid: bool = True) -> SignedData: + """ + Return a copy of the signed data + """ + + msg_hash = keccak(msg) + sig = sk.sign_msg_hash(msg_hash) + sig_v = sig.v + sig_r = sig.r if valid else U256(1) + sig_s = sig.s if valid else U256(1) + return SignedData(msg_hash, sig_v, sig_r, sig_s, int(sk.public_key.to_address(), 16), valid) + + +def signedData2witness( + signed_data: List[SignedData], + keccak_randomness: FQ, +) -> Witness: + """ + Generate the complete witness of a list of signed data. + """ + + rows: List[Row] = [] + keccak_table = KeccakTable() + for i, data in enumerate(signed_data): + sig = KeyAPI.Signature(vrs=(data.sig_v, data.sig_r, data.sig_s)) + pk = sig.recover_public_key_from_msg_hash(data.msg_hash) + ecdsa_chip = ECDSAVerifyChip.assign(sig, pk, data.msg_hash) + + pk_bytes = pk.to_bytes() + keccak_table.add(pk_bytes, keccak_randomness) + pk_hash = keccak(pk_bytes) + rows.append( + Row( + pk_hash, + FQ(data.addr), + Word(data.msg_hash), + ecdsa_chip, + ) + ) + + return Witness(rows, keccak_table) + + +def gen_witness(num: int = 10, valid: bool = True) -> Witness: + sks = [keys.PrivateKey(bytes([byte + 1]) * 32) for byte in range(num)] + + list: List[SignedData] = [] + for sk in sks: + signed_msg = sign_msg(sk, bytes("Message", "utf-8"), valid) + list.append(signed_msg) + + witness = signedData2witness(list, r) + return witness + + +def verify( + witness: Witness, + keccak_randomness: FQ, + success: bool = True, +): + """ + Verify the circuit with the assigned witness (or the witness calculated + from the transactions). If `success` is False, expect the verification to + fail. + """ + + exception = None + try: + verify_circuit( + witness, + keccak_randomness, + ) + except AssertionError as e: + exception = e + + if success: + if exception: + raise exception + assert exception is None + else: + assert exception is not None + + +def test_ecdsa_verify_chip(): + sk = keys.PrivateKey(b"\x02" * 32) + pk = sk.public_key + msg_hash = b"\xae" * 32 + sig = sk.sign_msg_hash(msg_hash) + + ecdsa_chip = ECDSAVerifyChip.assign(sig, pk, msg_hash) + assert ecdsa_chip.verify() == True + + +def test_sig_verify(): + witness = gen_witness() + verify(witness, r) + + +def test_sig_bad_keccak(): + witness = gen_witness() + # Set empty keccak lookup table + witness = Witness(witness.rows, KeccakTable()) + verify(witness, r, success=False) + + +def test_sig_bad_signature(): + witness = gen_witness(10, False) + verify(witness, r, success=False) + + +def test_sig_bad_msg_hash(): + witness = gen_witness(1) + witness.rows[0].msg_hash = Word(1) + verify(witness, r, success=False) + + +def test_sig_bad_address(): + witness = gen_witness(1) + witness.rows[0].recovered_addr = FQ(1) + verify(witness, r, success=False)