Skip to content
This repository has been archived by the owner on Feb 21, 2024. It is now read-only.

Commit

Permalink
Optimize num_bytes and hex_prefix_rlp (0xPolygonZero#1384)
Browse files Browse the repository at this point in the history
* Compute num_bytes non-deterministically

* Optimize hex_prefix_rlp

* Clean code

* Clippy

* Apply suggestions

Co-authored-by: Robin Salen <[email protected]>

* Clean

* Add endline

* Change 1^256 to U256_MAX

* Apply suggestions from code review

Co-authored-by: Robin Salen <[email protected]>

---------

Co-authored-by: Robin Salen <[email protected]>
  • Loading branch information
4l0n50 and Nashtare authored Nov 29, 2023
1 parent 64cc100 commit 471ff68
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 134 deletions.
156 changes: 89 additions & 67 deletions evm/src/cpu/kernel/asm/mpt/hex_prefix.asm
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@
//
// Pre stack: rlp_start_pos, num_nibbles, packed_nibbles, terminated, retdest
// Post stack: rlp_end_pos

global hex_prefix_rlp:
// stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// We will iterate backwards, from i = num_nibbles / 2 to i = 0, so that we
// can take nibbles from the least-significant end of packed_nibbles.
PUSH 2 DUP3 DIV // i = num_nibbles / 2
// stack: i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest

DUP2 %assert_lt_const(65)
PUSH 2 DUP3 DIV
// Compute the length of the hex-prefix string, in bytes:
// hp_len = num_nibbles / 2 + 1 = i + 1
DUP1 %increment
// stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
%increment
// stack: hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest

// Write the RLP header.
DUP1 %gt_const(55) %jumpi(rlp_header_large)
Expand All @@ -25,80 +21,106 @@ global hex_prefix_rlp:
// The hex-prefix is a single byte. It must be <= 127, since its first
// nibble only has two bits. So this is the "small" RLP string case, where
// the byte is its own RLP encoding.
// stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
%jump(start_loop)
// stack: hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
POP
// Writes the first byte of the hex-prefix encoding at rlp_pos, then jumps to
// retdest with rlp_pos + 1 on top of the stack.
// The byte is (parity + 2 * terminated) * 16 + first_nibble_or_zero, where
// parity = num_nibbles % 2.
// Pre stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// Post stack: rlp_pos + 1
first_byte:
// stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// get the first nibble, if num_nibbles is odd, or zero otherwise
SWAP2
// stack: packed_nibbles, num_nibbles, rlp_pos, terminated, retdest
DUP2 DUP1
%mod_const(2)
// stack: parity, num_nibbles, packed_nibbles, num_nibbles, rlp_pos, terminated, retdest
// Shift packed_nibbles right by 4 * (num_nibbles - parity) bits, i.e. drop
// an even number of trailing nibbles. Assuming packed_nibbles holds exactly
// num_nibbles nibbles, what remains is the leading nibble (odd case) or
// zero (even case).
SWAP1 SUB
%mul_const(4)
SHR
// stack: first_nibble_or_zero, num_nibbles, rlp_pos, terminated, retdest
SWAP2
// stack: rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest
SWAP3
// stack: terminated, num_nibbles, first_nibble_or_zero, rlp_pos, retdest
%mul_const(2)
// stack: terminated * 2, num_nibbles, first_nibble_or_zero, rlp_pos, retdest
SWAP1
// stack: num_nibbles, terminated * 2, first_nibble_or_zero, rlp_pos, retdest
%mod_const(2) // parity
ADD
// stack: parity + terminated * 2, first_nibble_or_zero, rlp_pos, retdest
// Move the flags into the high nibble and add the first nibble as the low one.
%mul_const(16)
ADD
// stack: first_byte, rlp_pos, retdest
// DUP2 copies rlp_pos to the top so it serves as the offset for %mstore_rlp.
DUP2
%mstore_rlp
%increment
// stack: rlp_pos', retdest
SWAP1
JUMP
// Writes the bytes that follow the first hex-prefix byte: the low
// (num_nibbles - parity) nibbles of packed_nibbles, stored as
// (num_nibbles - parity) / 2 bytes starting at rlp_pos, then jumps to retdest.
// Pre stack: rlp_pos, num_nibbles, packed_nibbles, retdest
remaining_bytes:
// stack: rlp_pos, num_nibbles, packed_nibbles, retdest
SWAP2
PUSH @U256_MAX
// stack: U256_MAX, packed_nibbles, num_nibbles, rlp_pos, retdest
SWAP1 SWAP2 DUP1
%mod_const(2)
// stack: parity, num_nibbles, U256_MAX, packed_nibbles, rlp_pos, retdest
SWAP1 SUB DUP1
// stack: num_nibbles - parity, num_nibbles - parity, U256_MAX, packed_nibbles, rlp_pos, retdest
%div_const(2)
// stack: remaining_bytes, num_nibbles - parity, U256_MAX, packed_nibbles, rlp_pos, retdest
SWAP2 SWAP1
// stack: num_nibbles - parity, U256_MAX, remaining_bytes, packed_nibbles, rlp_pos, retdest
%mul_const(4)
// stack: 4*(num_nibbles - parity), U256_MAX, remaining_bytes, packed_nibbles, rlp_pos, retdest
PUSH 256 SUB
// stack: 256 - 4*(num_nibbles - parity), U256_MAX, remaining_bytes, packed_nibbles, rlp_pos, retdest
// Shifting U256_MAX right by this amount leaves exactly the low
// 4*(num_nibbles - parity) bits set.
SHR
// stack: mask, remaining_bytes, packed_nibbles, rlp_pos, retdest
SWAP1 SWAP2
// The AND drops the leading nibble (if any), which first_byte has already
// folded into the first byte.
AND
%stack
(remaining_nibbles, remaining_bytes, rlp_pos) ->
(rlp_pos, remaining_nibbles, remaining_bytes)
// Arguments are now in (offset, value, len) order for %mstore_unpacking_rlp.
%mstore_unpacking_rlp
// NOTE(review): presumably leaves the updated rlp_pos above retdest, so that
// SWAP1 JUMP returns it to the caller; confirm against mstore_unpacking.
SWAP1
JUMP


rlp_header_medium:
// stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
DUP1 %add_const(0x80) // value = 0x80 + hp_len
DUP4 // offset = rlp_pos
// stack: hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
%add_const(0x80) // value = 0x80 + hp_len
DUP2 // offset = rlp_pos
%mstore_rlp

// stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// rlp_pos += 1
SWAP2 %increment SWAP2
%increment

%stack
(rlp_pos, num_nibbles, packed_nibbles, terminated, retdest) ->
(rlp_pos, num_nibbles, packed_nibbles, terminated, remaining_bytes, num_nibbles, packed_nibbles, retdest)

%jump(start_loop)
%jump(first_byte)

rlp_header_large:
// stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// stack: hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// In practice hex-prefix length will never exceed 256, so the length of the
// length will always be 1 byte in this case.

PUSH 0xb8 // value = 0xb7 + len_of_len = 0xb8
DUP4 // offset = rlp_pos
DUP3 // offset = rlp_pos
%mstore_rlp

DUP1 // value = hp_len
DUP4 %increment // offset = rlp_pos + 1
// stack: hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
DUP2 %increment // offset = rlp_pos + 1
%mstore_rlp

// stack: rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// rlp_pos += 2
SWAP2 %add_const(2) SWAP2
%add_const(2)

start_loop:
// stack: hp_len, i, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
SWAP1
%stack
(rlp_pos, num_nibbles, packed_nibbles, terminated, retdest) ->
(rlp_pos, num_nibbles, packed_nibbles, terminated, remaining_bytes, num_nibbles, packed_nibbles, retdest)

loop:
// stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
// If i == 0, break to first_byte.
DUP1 ISZERO %jumpi(first_byte)
%jump(first_byte)

// stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
DUP5 // packed_nibbles
%and_const(0xFF)
// stack: byte_i, i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
DUP4 // rlp_pos
DUP3 // i
ADD // We'll write to offset rlp_pos + i
%mstore_rlp

// stack: i, hp_len, rlp_pos, num_nibbles, packed_nibbles, terminated, retdest
%decrement
SWAP4 %shr_const(8) SWAP4 // packed_nibbles >>= 8
%jump(loop)

first_byte:
// stack: 0, hp_len, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest
POP
// stack: hp_len, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest
DUP2 ADD
// stack: rlp_end_pos, rlp_pos, num_nibbles, first_nibble_or_zero, terminated, retdest
SWAP4
// stack: terminated, rlp_pos, num_nibbles, first_nibble_or_zero, rlp_end_pos, retdest
%mul_const(2)
// stack: terminated * 2, rlp_pos, num_nibbles, first_nibble_or_zero, rlp_end_pos, retdest
%stack (terminated_x2, rlp_pos, num_nibbles, first_nibble_or_zero)
-> (num_nibbles, terminated_x2, first_nibble_or_zero, rlp_pos)
// stack: num_nibbles, terminated * 2, first_nibble_or_zero, rlp_pos, rlp_end_pos, retdest
%mod_const(2) // parity
ADD
// stack: parity + terminated * 2, first_nibble_or_zero, rlp_pos, rlp_end_pos, retdest
%mul_const(16)
ADD
// stack: first_byte, rlp_pos, rlp_end_pos, retdest
SWAP1
%mstore_rlp
// stack: rlp_end_pos, retdest
SWAP1
JUMP
6 changes: 6 additions & 0 deletions evm/src/cpu/kernel/asm/rlp/encode.asm
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,9 @@ global mstore_unpacking_rlp:
PUSH @SEGMENT_RLP_RAW
PUSH 0 // context
%jump(mstore_unpacking)

// Convenience macro to call mstore_unpacking_rlp and continue at the
// instruction that follows the macro invocation.
// Pre stack: offset, value, len
// The %stack directive places the macro-local label %%after beneath the three
// arguments before jumping.
// NOTE(review): %%after is presumably consumed as the callee's retdest;
// confirm against mstore_unpacking.
%macro mstore_unpacking_rlp
%stack (offset, value, len) -> (offset, value, len, %%after)
%jump(mstore_unpacking_rlp)
%%after:
%endmacro
79 changes: 12 additions & 67 deletions evm/src/cpu/kernel/asm/rlp/num_bytes.asm
Original file line number Diff line number Diff line change
@@ -1,78 +1,23 @@
// Get the number of bytes required to represent the given scalar.
// Note that we define num_bytes(0) to be 1.

global num_bytes:
// stack: x, retdest
DUP1 PUSH 0 BYTE %jumpi(return_32)
DUP1 PUSH 1 BYTE %jumpi(return_31)
DUP1 PUSH 2 BYTE %jumpi(return_30)
DUP1 PUSH 3 BYTE %jumpi(return_29)
DUP1 PUSH 4 BYTE %jumpi(return_28)
DUP1 PUSH 5 BYTE %jumpi(return_27)
DUP1 PUSH 6 BYTE %jumpi(return_26)
DUP1 PUSH 7 BYTE %jumpi(return_25)
DUP1 PUSH 8 BYTE %jumpi(return_24)
DUP1 PUSH 9 BYTE %jumpi(return_23)
DUP1 PUSH 10 BYTE %jumpi(return_22)
DUP1 PUSH 11 BYTE %jumpi(return_21)
DUP1 PUSH 12 BYTE %jumpi(return_20)
DUP1 PUSH 13 BYTE %jumpi(return_19)
DUP1 PUSH 14 BYTE %jumpi(return_18)
DUP1 PUSH 15 BYTE %jumpi(return_17)
DUP1 PUSH 16 BYTE %jumpi(return_16)
DUP1 PUSH 17 BYTE %jumpi(return_15)
DUP1 PUSH 18 BYTE %jumpi(return_14)
DUP1 PUSH 19 BYTE %jumpi(return_13)
DUP1 PUSH 20 BYTE %jumpi(return_12)
DUP1 PUSH 21 BYTE %jumpi(return_11)
DUP1 PUSH 22 BYTE %jumpi(return_10)
DUP1 PUSH 23 BYTE %jumpi(return_9)
DUP1 PUSH 24 BYTE %jumpi(return_8)
DUP1 PUSH 25 BYTE %jumpi(return_7)
DUP1 PUSH 26 BYTE %jumpi(return_6)
DUP1 PUSH 27 BYTE %jumpi(return_5)
DUP1 PUSH 28 BYTE %jumpi(return_4)
DUP1 PUSH 29 BYTE %jumpi(return_3)
PUSH 30 BYTE %jumpi(return_2)
DUP1 ISZERO %jumpi(return_1)
// Non-deterministically guess the number of bits
PROVER_INPUT(num_bits)
%stack (num_bits, x) -> (num_bits, x, num_bits)
%decrement
SHR
// stack: x >> (num_bits - 1), num_bits, retdest (the top value is asserted to equal 1 below)
%assert_eq_const(1)
// convert number of bits to number of bytes
%add_const(7)
%shr_const(3)

// If we got all the way here, each byte was zero, except possibly the least
// significant byte, which we didn't check. Either way, the result is 1.
// stack: retdest
PUSH 1
SWAP1
JUMP

return_2: PUSH 2 SWAP1 JUMP
return_3: POP PUSH 3 SWAP1 JUMP
return_4: POP PUSH 4 SWAP1 JUMP
return_5: POP PUSH 5 SWAP1 JUMP
return_6: POP PUSH 6 SWAP1 JUMP
return_7: POP PUSH 7 SWAP1 JUMP
return_8: POP PUSH 8 SWAP1 JUMP
return_9: POP PUSH 9 SWAP1 JUMP
return_10: POP PUSH 10 SWAP1 JUMP
return_11: POP PUSH 11 SWAP1 JUMP
return_12: POP PUSH 12 SWAP1 JUMP
return_13: POP PUSH 13 SWAP1 JUMP
return_14: POP PUSH 14 SWAP1 JUMP
return_15: POP PUSH 15 SWAP1 JUMP
return_16: POP PUSH 16 SWAP1 JUMP
return_17: POP PUSH 17 SWAP1 JUMP
return_18: POP PUSH 18 SWAP1 JUMP
return_19: POP PUSH 19 SWAP1 JUMP
return_20: POP PUSH 20 SWAP1 JUMP
return_21: POP PUSH 21 SWAP1 JUMP
return_22: POP PUSH 22 SWAP1 JUMP
return_23: POP PUSH 23 SWAP1 JUMP
return_24: POP PUSH 24 SWAP1 JUMP
return_25: POP PUSH 25 SWAP1 JUMP
return_26: POP PUSH 26 SWAP1 JUMP
return_27: POP PUSH 27 SWAP1 JUMP
return_28: POP PUSH 28 SWAP1 JUMP
return_29: POP PUSH 29 SWAP1 JUMP
return_30: POP PUSH 30 SWAP1 JUMP
return_31: POP PUSH 31 SWAP1 JUMP
return_32: POP PUSH 32 SWAP1 JUMP
return_1: POP PUSH 1 SWAP1 JUMP

// Convenience macro to call num_bytes and return where we left off.
%macro num_bytes
Expand Down
13 changes: 13 additions & 0 deletions evm/src/generation/prover_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl<F: Field> GenerationState<F> {
"account_code" => self.run_account_code(input_fn),
"bignum_modmul" => self.run_bignum_modmul(),
"withdrawal" => self.run_withdrawal(),
"num_bits" => self.run_num_bits(),
_ => Err(ProgramError::ProverInputError(InvalidFunction)),
}
}
Expand Down Expand Up @@ -221,6 +222,18 @@ impl<F: Field> GenerationState<F> {
.pop()
.ok_or(ProgramError::ProverInputError(OutOfWithdrawalData))
}

/// Return the number of bits of the top of the stack or an error if
/// the top of the stack is zero or empty.
fn run_num_bits(&mut self) -> Result<U256, ProgramError> {
let value = stack_peek(self, 0)?;
if value.is_zero() {
Err(ProgramError::ProverInputError(NumBitsError))
} else {
let num_bits = value.bits();
Ok(num_bits.into())
}
}
}

enum EvmField {
Expand Down
1 change: 1 addition & 0 deletions evm/src/witness/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ pub enum ProverInputError {
InvalidMptInput,
InvalidInput,
InvalidFunction,
NumBitsError,
}

0 comments on commit 471ff68

Please sign in to comment.