From 416b5c711aac80913ae600c50d1dba8f251a7467 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Thu, 26 Oct 2023 13:19:10 +0800 Subject: [PATCH] feat: impl ecc_circuit and ecc_table --- specs/ecc-proof.md | 52 +++++++++ specs/precompile/06ecAdd.md | 0 specs/tables.md | 11 ++ src/zkevm_specs/ecc_circuit.py | 165 +++++++++++++++++++++++++++ src/zkevm_specs/evm_circuit/table.py | 27 +++++ src/zkevm_specs/util/ec.py | 61 ++++++++-- tests/test_ecc_circuit.py | 110 ++++++++++++++++++ 7 files changed, 419 insertions(+), 7 deletions(-) create mode 100644 specs/ecc-proof.md create mode 100644 specs/precompile/06ecAdd.md create mode 100644 src/zkevm_specs/ecc_circuit.py create mode 100644 tests/test_ecc_circuit.py diff --git a/specs/ecc-proof.md b/specs/ecc-proof.md new file mode 100644 index 000000000..83fe10036 --- /dev/null +++ b/specs/ecc-proof.md @@ -0,0 +1,52 @@ +# Ecc Proof + +[Elliptic Curve Digital Signature Algorithm]: https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm + +According to the [Elliptic Curve Digital Signature Algorithm] (ECDSA), the signatures `(r,s)` are calculated via ECDSA from `msg_hash` and a `public_key` using the formula + +`(r,s)=ecdsa(msg_hash, public_key)` + +The `public_key` is obtained from `private_key` by mapping the latter to an elliptic curve (EC) point. The `r` is the x-component of an EC point, and the same EC point's y-component will be used to determine the recovery id `v = y%2` (the parity of y). Given the signature `(v, r, s)`, the `public_key` can be recovered from `(v, r, s)` and `msg_hash` using `ecrecover`. + + +## Circuit behavior + +SigTable built inside zkevm-circuits is used to verify signatures. It has the following columns: +- `msg_hash`: Advice Column, the Keccak256 hash of the message that's signed; +- `sig_v`: Advice Column, the recovery id, either 0 or 1, it should be the parity of y; +- `sig_r`: Advice Column, the signature's `r` component; +- `sig_s`: Advice Column, the signature's `s` component; +- `recovered_addr`: Advice Column, the recovered address, i.e. the 20-bytes address that must have signed the message; +- `is_valid`: Advice Column, indicates whether or not the signature is valid or not upon signature verification. + +Constraints on the shape of the table is like: + +| 0 op_type | 1 input_1 | 2 input_2 | 3 output | 4 is_valid | +| --------- | ------------- | ------------- | ------------- | ---------- | +| $tag | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | + +* tag: ADD, MUL and PAIRING + +The Sig Circuit aims at proving the correctness of SigTable. This mainly includes the following type of constraints: +- checking that the signature is obtained correctly. This is done by the ECDSA chip, and the correctness of `v` is checked separately; +- checking that `msg_hash` is obtained correctly from Keccak hash function. This is done by lookup to Keccak table; + + +## Constraints + +`assign_ecdsa` method takes the signature data and uses ECDSA chip to verify its correctness. The verification result `sig_is_valid` will be returned. The recovery id `v` value will be computed and verified. + +`sign_data_decomposition` method takes the signature data and the return values of `assign_ecdsa`, and returns the cells for byte decomposition of the keys and messages in the form of `SignDataDecomposed`. The latter consists of the following contents: +- `SignDataDecomposed` + - `pk_hash_cells`: byte cells for keccak256 hash of public key; + - `msg_hash_cells`: byte cells for `msg_hash`; + - `pk_cells`: byte cells for the EC coordinates of public key; + - `address`: RLC of `pk_hash` last 20 bytes; + - `is_address_zero`: check if address is zero; + - `r_cells`, `s_cells`: byte cells for signatures `r` and `s`. + +The decomposed sign data are sent to `assign_sign_verify` method to compute and verify their RLC values and perform Keccak lookup checks. + +## Code + +Please refer to `src/zkevm-specs/sig_circuit.py` \ No newline at end of file diff --git a/specs/precompile/06ecAdd.md b/specs/precompile/06ecAdd.md new file mode 100644 index 000000000..e69de29bb diff --git a/specs/tables.md b/specs/tables.md index 677d68a8f..4109b938e 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -379,3 +379,14 @@ The circuit verifies the correctness of signatures. NOTE: - `sig_v` is either 0 or 1 so boolean type is used here. + + +## Elliptic Curve Table + +Proved by the Elliptic Curve circuit. + +| 0 op_type | 1 input_a | 2 input_b | 3 output | 4 is_valid | +| --------- | ------------- | ------------- | ------------- | ---------- | +| $tag | $value{Lo,Hi} | $value{Lo,Hi} | $value{Lo,Hi} | bool | + +- **tag**: supports `Add`, `Mul` and `Pairing` diff --git a/src/zkevm_specs/ecc_circuit.py b/src/zkevm_specs/ecc_circuit.py new file mode 100644 index 000000000..a63d802cb --- /dev/null +++ b/src/zkevm_specs/ecc_circuit.py @@ -0,0 +1,165 @@ +from __future__ import annotations +from typing import List, Sequence, Tuple +from py_ecc.bn128.bn128_curve import is_inf, is_on_curve, b +from .evm_circuit import EccTableRow +from .util import ConstraintSystem, FQ, Word, ECCVerifyChip +from zkevm_specs.evm_circuit.table import EccOpTag + + +class EccCircuitRow: + """ + ECC circuit + """ + + row: EccTableRow + + ecc_chip: ECCVerifyChip + + def __init__(self, row: EccTableRow, ecc_chip: ECCVerifyChip) -> None: + self.row = row + self.ecc_chip = ecc_chip + + @classmethod + def check_fq(cls, value: int) -> bool: + return value < int(FQ.field_modulus) + + @classmethod + def assign( + cls, op_type: EccOpTag, p0: Tuple[Word, Word], p1: Tuple[Word, Word], out: Tuple[Word, Word] + ): + if op_type == EccOpTag.Add: + return cls.assign_add(p0, p1, out) + elif op_type == EccOpTag.Mul: + return cls.assign_add(p0, p1, out) + elif op_type == EccOpTag.Pairing: + return cls.assign_add(p0, p1, out) + else: + raise TypeError(f"Not supported type: {op_type}") + + @classmethod + def assign_add(cls, p0: Tuple[Word, Word], p1: Tuple[Word, Word], out: Tuple[Word, Word]): + # 1. verify validity of input points p0 and p1 + precheck_p0x = cls.check_fq(p0[0].int_value()) + precheck_p0y = cls.check_fq(p0[1].int_value()) + precheck_p1x = cls.check_fq(p1[0].int_value()) + precheck_p1y = cls.check_fq(p1[1].int_value()) + + point0 = (FQ(p0[0].int_value()), FQ(p0[1].int_value())) + point1 = (FQ(p1[0].int_value()), FQ(p1[1].int_value())) + is_valid_p0 = is_on_curve(point0, b) + is_valid_p1 = is_on_curve(point1, b) + is_infinite_p0 = is_inf(point0) + is_infinite_p1 = is_inf(point1) + is_valid_points = (is_valid_p0 or is_infinite_p0) and (is_valid_p1 or is_infinite_p1) + + is_valid = ( + precheck_p0x and precheck_p0y and precheck_p1x and precheck_p1y and is_valid_points + ) + + self_p0_x = p0[0].int_value() + self_p0_y = p0[1].int_value() + self_p1_x = p1[0].int_value() + self_p1_y = p1[1].int_value() + self_output_x = out[0].int_value() + self_output_y = out[1].int_value() + + ecc_chip = ECCVerifyChip.assign( + p0=(FQ(self_p0_x), FQ(self_p0_y)), + p1=(FQ(self_p1_x), FQ(self_p1_y)), + output=(FQ(self_output_x), FQ(self_output_y)), + ) + ecc_table = EccTableRow( + FQ(EccOpTag.Add), + Word(self_p0_x), + Word(self_p0_y), + Word(self_p1_x), + Word(self_p1_y), + Word(self_output_x), + Word(self_output_y), + FQ(is_valid), + ) + + return cls(ecc_table, ecc_chip) + + @classmethod + def assign_mul(cls, p0: Tuple[Word, Word], p1: Tuple[Word, Word], out: Tuple[Word, Word]): + raise NotImplementedError("assign_mul is not supported yet") + + @classmethod + def assign_pairing(cls, p0: Tuple[Word, Word], p1: Tuple[Word, Word], out: Tuple[Word, Word]): + raise NotImplementedError("assign_pairing is not supported yet") + + def verify( + self, cs: ConstraintSystem, max_add_ops: int, max_mul_ops: int, max_pairing_ops: int + ): + # Copy constraints between EccTable and ECCVerifyChip + cs.constrain_equal_word(Word(self.ecc_chip.p0[0].n), self.row.px) + cs.constrain_equal_word(Word(self.ecc_chip.p0[1].n), self.row.py) + cs.constrain_equal_word(Word(self.ecc_chip.p1[0].n), self.row.qx) + cs.constrain_equal_word(Word(self.ecc_chip.p1[1].n), self.row.qy) + cs.constrain_equal_word(Word(self.ecc_chip.output[0].n), self.row.out_x) + cs.constrain_equal_word(Word(self.ecc_chip.output[1].n), self.row.out_y) + + is_add = cs.is_equal(self.row.op_type, FQ(EccOpTag.Add)) + is_mul = cs.is_equal(self.row.op_type, FQ(EccOpTag.Mul)) + is_pairing = cs.is_equal(self.row.op_type, FQ(EccOpTag.Pairing)) + # Must be one of above operations + cs.constrain_equal(is_add + is_mul + is_pairing, FQ(1)) + + num_add = 0 + num_mul = 0 + num_pairing = 0 + if is_add == FQ(1): + num_add += 1 + assert ( + num_add <= max_add_ops + ), f"exceeds max number of add operation, max_add_ops: {max_add_ops}" + cs.constrain_equal(FQ(self.ecc_chip.verify_add()), self.row.is_valid) + + if is_mul == FQ(1): + num_mul += 1 + assert ( + num_mul <= max_mul_ops + ), f"exceeds max number of mul operation, max_mul_ops: {max_mul_ops}" + cs.constrain_equal(FQ(self.ecc_chip.verify_mul()), self.row.is_valid) + + if is_pairing == FQ(1): + num_pairing += 1 + assert ( + num_pairing <= max_pairing_ops + ), f"exceeds max number of pairing operation, max_pairing_ops: {max_pairing_ops}" + cs.constrain_equal(FQ(self.ecc_chip.verify_pairing()), self.row.is_valid) + + +class EccCircuit: + rows: List[EccCircuitRow] + max_add_ops: int + max_mul_ops: int + max_pairing_ops: int + + def __init__( + self, + max_add_ops: int, + max_mul_ops: int, + max_pairing_ops: int, + ) -> None: + self.rows = [] + self.max_add_ops = max_add_ops + self.max_mul_ops = max_mul_ops + self.max_pairing_ops = max_pairing_ops + + def table(self) -> Sequence[EccCircuitRow]: + return self.rows + + def add(self, row: EccCircuitRow) -> EccCircuit: + self.rows.append(row) + return self + + +def verify_circuit(circuit: EccCircuit) -> None: + """ + Entry level circuit verification function + """ + cs = ConstraintSystem() + for row in circuit.table(): + row.verify(cs, circuit.max_add_ops, circuit.max_mul_ops, circuit.max_pairing_ops) diff --git a/src/zkevm_specs/evm_circuit/table.py b/src/zkevm_specs/evm_circuit/table.py index a53fba839..abe3003e3 100644 --- a/src/zkevm_specs/evm_circuit/table.py +++ b/src/zkevm_specs/evm_circuit/table.py @@ -350,6 +350,16 @@ def from_account_field_tag(field_tag: AccountFieldTag) -> MPTProofType: raise Exception("Unexpected AccountFieldTag value") +class EccOpTag(IntEnum): + """ + Tag for EccTable that specifies the operation over ECC + """ + + Add = auto() # addition of two EC points + Mul = auto() # multiplication of two EC points + Pairing = auto() # pairing of two EC points + + class WrongQueryKey(Exception): def __init__(self, table_name: str, diff: Set[str]) -> None: self.message = f"Lookup {table_name} with invalid keys {diff}" @@ -545,6 +555,19 @@ class SigTableRow(TableRow): sig_r: Word sig_s: Word recovered_addr: FQ + + +@dataclass(frozen=True) +class EccTableRow(TableRow): + op_type: FQ + px: Word + py: Word + # qx is the scalar and qy must be zero if op_type is multiple + qx: Word + qy: Word + + out_x: Word + out_y: Word is_valid: FQ @@ -563,6 +586,7 @@ class Tables: keccak_table: Set[KeccakTableRow] exp_table: Set[ExpTableRow] sig_table: Set[SigTableRow] + ecc_table: Set[EccTableRow] def __init__( self, @@ -575,6 +599,7 @@ def __init__( keccak_table: Optional[Sequence[KeccakTableRow]] = None, exp_circuit: Optional[Sequence[ExpCircuitRow]] = None, sig_table: Optional[Sequence[SigTableRow]] = None, + ecc_table: Optional[Sequence[EccTableRow]] = None, ) -> None: self.block_table = block_table self.tx_table = tx_table @@ -592,6 +617,8 @@ def __init__( self.exp_table = self._convert_exp_circuit_to_table(exp_circuit) if sig_table is not None: self.sig_table = set(sig_table) + if ecc_table is not None: + self.ecc_table = set(ecc_table) def _convert_copy_circuit_to_table(self, copy_circuit: Sequence[CopyCircuitRow]): rows: List[CopyTableRow] = [] diff --git a/src/zkevm_specs/util/ec.py b/src/zkevm_specs/util/ec.py index eea102109..be1d93d81 100644 --- a/src/zkevm_specs/util/ec.py +++ b/src/zkevm_specs/util/ec.py @@ -1,6 +1,8 @@ +from __future__ import annotations from typing import Tuple -from .arithmetic import FQ from eth_keys import KeyAPI # type: ignore +from .arithmetic import FQ +from py_ecc.bn128.bn128_curve import add, eq class WrongFieldInteger: @@ -21,15 +23,19 @@ def __init__(self, value: int) -> None: self.limbs = (FQ(l0), FQ(l1), FQ(l2), FQ(l3)) self.le_bytes = value.to_bytes(32, "little") - def to_le_bytes(self) -> bytes: + def to_int_value(self) -> int: (l0, l1, l2, l3) = self.limbs - val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) - return val.to_bytes(32, "little") + return l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) + + def to_le_bytes(self) -> bytes: + return self.to_int_value().to_bytes(32, "little") def to_be_bytes(self) -> bytes: - (l0, l1, l2, l3) = self.limbs - val = l0.n + (l1.n << 1 * 72) + (l2.n << 2 * 72) + (l3.n << 3 * 72) - return val.to_bytes(32, "big") + return self.to_int_value().to_bytes(32, "big") + + def add(self, other: WrongFieldInteger) -> int: + # Python will extend the size if it exceeds 32bytes. So, don't need to take care overflow here. + return self.to_int_value() + other.to_int_value() class Secp256k1BaseField(WrongFieldInteger): @@ -111,3 +117,44 @@ def verify(self) -> bool: msg_hash = bytes(self.msg_hash.to_be_bytes()) public_key = KeyAPI.PublicKey(self.pub_key[0].to_be_bytes() + self.pub_key[1].to_be_bytes()) return KeyAPI().ecdsa_verify(msg_hash, signature, public_key) + + +class ECCVerifyChip: + """ + ECC Verification Chip. This represents an ECC verification Chip as implemented in + https://github.com/privacy-scaling-explorations/halo2wrong/blob/master/ecc/src/general_field_ecc.rs + """ + + p0: Tuple[FQ, FQ] + p1: Tuple[FQ, FQ] + output: Tuple[FQ, FQ] + + def __init__( + self, + p0: Tuple[FQ, FQ], + p1: Tuple[FQ, FQ], + output: Tuple[FQ, FQ], + ) -> None: + self.p0 = p0 + self.p1 = p1 + self.output = output + + @classmethod + def assign( + cls, + p0: Tuple[FQ, FQ], + p1: Tuple[FQ, FQ], + output: Tuple[FQ, FQ], + ): + return cls((p0[0], p0[1]), (p1[0], p1[1]), (output[0], output[1])) + + def verify_add(self) -> bool: + a = add(self.p0, self.p1) + print(f"{a[0].n}, {a[1].n}") + return eq(add(self.p0, self.p1), self.output) + + def verify_mul(self) -> bool: + raise NotImplementedError("verify_mul is not supported yet") + + def verify_pairing(self) -> bool: + raise NotImplementedError("verify_pairing is not supported yet") diff --git a/tests/test_ecc_circuit.py b/tests/test_ecc_circuit.py new file mode 100644 index 000000000..57332fdf3 --- /dev/null +++ b/tests/test_ecc_circuit.py @@ -0,0 +1,110 @@ +import pytest +from typing import Tuple, NamedTuple +from zkevm_specs.ecc_circuit import EccCircuitRow, verify_circuit, EccCircuit +from zkevm_specs.evm_circuit.table import EccOpTag +from zkevm_specs.util import Word + + +class EccOps(NamedTuple): + op_type: EccOpTag + p: Tuple[int, int] + q: Tuple[int, int] + out: Tuple[int, int] + + +def verify( + circuit: EccCircuit, + success: bool = True, +): + """ + Verify the circuit with the assigned witness. + If `success` is False, expect the verification to fail. + """ + + exception = None + try: + verify_circuit(circuit) + except Exception as e: + exception = e + if success: + if exception: + raise exception + assert exception is None + else: + assert exception is not None + + +def gen_ecAdd_testing_data(): + op = EccOpTag.Add + + normal = ( + EccOps( + op, + p=(1, 2), + q=(1, 2), + out=( + 0x030644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD3, + 0x15ED738C0E0A7C92E7845F96B2AE9C0A68A6A449E3538FC7FF3EBF7A5A18A2C4, + ), + ), + True, + ) + # p is not on the curve + invalid_p = ( + EccOps( + op, + p=(2, 3), + q=(1, 2), + out=(0, 0), + ), + True, + ) + incorrect_out = ( + EccOps( + op, + p=(1, 2), + q=(1, 2), + out=( + 0x030644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD0, + 0x15ED738C0E0A7C92E7845F96B2AE9C0A68A6A449E3538FC7FF3EBF7A5A18A2C4, + ), + ), + False, + ) + # q = -p + # py_ecc doesn't support this case, it returns (0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD45, 1) + # p_plus_neg_p = ( + # EccOps( + # op, + # p=(1, 2), + # q=(1, 0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD45), + # out=(0, 0), + # ), + # True, + # ) + return [normal, invalid_p, incorrect_out] + + +TESTING_DATA = gen_ecAdd_testing_data() + + +@pytest.mark.parametrize( + "ecc_ops, success", + TESTING_DATA, +) +def test_ecc_add(ecc_ops: EccOps, success: bool): + MAX_ECADD_OPS = 5 + MAX_ECMUL_OPS = 0 + MAX_ECPAIRING_OPS = 0 + + circuit = EccCircuit(MAX_ECADD_OPS, MAX_ECMUL_OPS, MAX_ECPAIRING_OPS) + ecc_ops = gen_ecAdd_testing_data() + for ec_op, success in ecc_ops: + row = EccCircuitRow.assign( + ec_op.op_type, + (Word(ec_op.p[0]), Word(ec_op.p[1])), + (Word(ec_op.q[0]), Word(ec_op.q[1])), + (Word(ec_op.out[0]), Word(ec_op.out[1])), + ) + circuit.add(row) + verify(circuit, success)