Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KGA-34] feat: update base fee starting from next block only #1606

Merged
merged 4 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/cairo-zero-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: kkrt-labs/ef-tests
ref: feat/update-base-fee-state
- name: Checkout local skip file
uses: actions/checkout@v4
with:
Expand Down
32 changes: 30 additions & 2 deletions cairo_zero/backend/starknet.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.math_cmp import is_nn
from starkware.cairo.common.math_cmp import is_nn, is_le
from starkware.cairo.common.memset import memset
from starkware.starknet.common.syscalls import (
emit_event,
Expand All @@ -23,6 +23,7 @@ from kakarot.precompiles.precompiles_helpers import PrecompilesHelpers
from kakarot.constants import Constants
from kakarot.interfaces.interfaces import IERC20, IAccount
from kakarot.model import model
from utils.utils import Helpers
from kakarot.state import State
from kakarot.storages import (
Kakarot_native_token_address,
Expand Down Expand Up @@ -110,7 +111,7 @@ namespace Starknet {
let (block_number) = get_block_number();
let (block_timestamp) = get_block_timestamp();
let (coinbase) = Kakarot_coinbase.read();
let (base_fee) = Kakarot_base_fee.read();
let (base_fee) = get_base_fee();
let (block_gas_limit) = Kakarot_block_gas_limit.read();
let (prev_randao) = Kakarot_prev_randao.read();

Expand All @@ -128,6 +129,33 @@ namespace Starknet {
base_fee=base_fee,
);
}

// @notice Get the block base fee.
// @dev Implemented here, used in `library.cairo` to avoid a circular dependency issue.
// @dev If the block_number of the existing "next_block" entry is greater or equal to the current block_number,
// then we return the base fee of index 'next_block' and update the one of 'current_block' to be the same.
// Otherwise, we return the base fee of index 'current_block'.
// @return base_fee The current block base fee.
func get_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
base_fee: felt
) {
alloc_locals;
let (res_next_block) = Kakarot_base_fee.read('next_block');
let next_base_fee = res_next_block[0];
let next_start_block_number = res_next_block[1];
let (block_number) = get_block_number();
let is_next_fee_ready = is_le(next_start_block_number, block_number);

if (next_start_block_number != 0 and is_next_fee_ready != FALSE) {
// update current_block storage and return next_block value
Kakarot_base_fee.write('current_block', (next_base_fee, block_number));
Kakarot_base_fee.write('next_block', (0, 0));
return (next_base_fee,);
}

let (res_current_block) = Kakarot_base_fee.read('current_block');
return (res_current_block[0],);
}
}

namespace Internals {
Expand Down
27 changes: 22 additions & 5 deletions cairo_zero/kakarot/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from openzeppelin.security.reentrancyguard.library import ReentrancyGuard
from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.starknet.common.syscalls import get_caller_address, get_tx_info
from starkware.cairo.common.math_cmp import is_not_zero
from starkware.starknet.common.syscalls import get_caller_address, get_tx_info, get_block_number
from starkware.cairo.common.math_cmp import is_not_zero, is_le
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.math import split_felt
Expand Down Expand Up @@ -160,21 +160,38 @@
}

// @notice Set the block base fee.
// @dev There can only be one base fee for a given block. Thus, we use an index to manage two different base fees:
// - The base fee to use for the current block (index: 'current_block')
// - The base fee to use for the next block (index: 'next_block')
// @param base_fee The new base fee.
func set_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
base_fee: felt
) {
Kakarot_base_fee.write(base_fee);
alloc_locals;
let (block_number) = get_block_number();
let (res_next_block) = Kakarot_base_fee.read('next_block');
let next_base_fee = res_next_block[0];
let starting_block = res_next_block[1];
Kakarot_base_fee.write('next_block', (base_fee, block_number + 1));

let is_next_fee_ready = is_le(starting_block, block_number);
if (is_next_fee_ready == FALSE) {
return ();

Check warning on line 179 in cairo_zero/kakarot/library.cairo

View check run for this annotation

Codecov / codecov/patch

cairo_zero/kakarot/library.cairo#L179

Added line #L179 was not covered by tests
}

Kakarot_base_fee.write('current_block', (next_base_fee, starting_block));
return ();
}

// @notice Get the block base fee.
// @dev If the block_number of the existing "next_block" entry is greater or equal to the current block_number,
// then we return the base fee of index 'next_block' and update the one of 'current_block' to be the same.
// Otherwise, we return the base fee of index 'current_block'.
// @return base_fee The current block base fee.
func get_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
base_fee: felt
) {
let (base_fee) = Kakarot_base_fee.read();
return (base_fee,);
return Starknet.get_base_fee();
}

// @notice Set the coinbase address.
Expand Down
6 changes: 5 additions & 1 deletion cairo_zero/kakarot/storages.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ func Kakarot_evm_to_starknet_address(evm_address: felt) -> (starknet_address: fe
func Kakarot_coinbase() -> (res: felt) {
}

// @notice The base fee set for kakarot
// @dev There can only be one base fee for a given block. Thus, we use an index to manage two different base fees:
// - The base fee to use for the current block (index: 'current_block')
// - The base fee to applicable starting next block (index: 'next_block')
@storage_var
func Kakarot_base_fee() -> (res: felt) {
func Kakarot_base_fee(index: felt) -> ((base_fee: felt, block_number: felt),) {
}

@storage_var
Expand Down
5 changes: 5 additions & 0 deletions cairo_zero/tests/src/kakarot/test_kakarot.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ func test__set_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_ch
return ();
}

func test__get_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> felt {
let (base_fee) = Kakarot.get_base_fee();
return base_fee;
}

func test__set_prev_randao{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() {
tempvar prev_randao;

Expand Down
141 changes: 138 additions & 3 deletions cairo_zero/tests/src/kakarot/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,133 @@ def test_should_assert_only_owner(self, cairo_run):
cairo_run("test__set_base_fee", base_fee=0xABC)

@SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address)
def test_should_set_base_fee(self, cairo_run):
base_fee = 0x100
@patch.object(SyscallHandler, "block_number", 0x100)
def test_set_base_fee_should_set_next_block_fee(self, cairo_run):
base_fee = 1
cairo_run("test__set_base_fee", base_fee=base_fee)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address("Kakarot_base_fee"), value=base_fee
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=base_fee,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x101,
)

@SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=[1, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x102)
def test_set_base_fee_should_overwrite_current_block_fee_if_next_block_is_applicable(
self, cairo_run
):
# Because block_number == 102, the mocked 'next_block', available
# since block 101, should be moved to 'current_block'.
cairo_run("test__set_base_fee", base_fee=2)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=1,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
)
+ 1,
value=0x101,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=2,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x103,
)

@SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=[1, 0x100],
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=[2, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x100)
def test_get_base_fee_should_return_current_block_fee_if_next_block_is_not_applicable(
self, cairo_run
):
base_fee = cairo_run("test__get_base_fee")
assert base_fee == 1

@SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=[1, 0x100],
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=[2, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x101)
def test_get_base_fee_should_return_next_block_fee_if_applicable_and_update_current_block(
self, cairo_run
):
base_fee = cairo_run("test__get_base_fee")
assert base_fee == 2

# Should update current block base fee
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=base_fee,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
)
+ 1,
value=0x101,
)

# Should nullify the value from next_block
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=0,
)
SyscallHandler.mock_storage.assert_any_call(
address=get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0,
)

class TestCoinbase:
Expand Down Expand Up @@ -676,6 +798,19 @@ def test_raise_transaction_gas_limit_too_high(self, cairo_run, tx):
)

@SyscallHandler.patch("Kakarot_block_gas_limit", TRANSACTION_GAS_LIMIT)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=TRANSACTION_GAS_LIMIT * 10**10,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x100,
)
@SyscallHandler.patch("Kakarot_base_fee", TRANSACTION_GAS_LIMIT * 10**10)
@SyscallHandler.patch("Kakarot_chain_id", CHAIN_ID)
@pytest.mark.parametrize("tx", TRANSACTIONS)
Expand Down
10 changes: 8 additions & 2 deletions tests/utils/syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from hashlib import sha256
from typing import Optional, Union
from typing import Iterable, Optional, Union
from unittest import mock

import ecdsa
Expand Down Expand Up @@ -609,7 +609,13 @@ def patch(
selector_if_storage = get_storage_var_address(target, *args)
else:
selector_if_storage = target
cls.patches[selector_if_storage] = value

if isinstance(value, Iterable):
for i, v in enumerate(value):
cls.patches[selector_if_storage + i] = v
else:
cls.patches[selector_if_storage] = value

except AssertionError:
pass

Expand Down
Loading