Skip to content

Commit

Permalink
make whitelisted checks based on caller address (instead of code
Browse files Browse the repository at this point in the history
address) and only register succeeded cairo call in message
  • Loading branch information
enitrat committed Nov 12, 2024
1 parent 07c3164 commit f33f644
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 59 deletions.
17 changes: 7 additions & 10 deletions cairo_zero/kakarot/interpreter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,11 @@ namespace Interpreter {
let is_precompile = PrecompilesHelpers.is_precompile(evm.message.code_address.evm);
if (is_precompile != FALSE) {
let parent_context = evm.message.parent;
let is_parent_zero = Helpers.is_zero(cast(parent_context, felt));
if (is_parent_zero != FALSE) {
// Case A: The precompile is called straight from an EOA
tempvar caller_code_address = evm.message.caller;
} else {
// Case B: The precompile is called from a contract
tempvar caller_code_address = parent_context.evm.message.code_address.evm;
}
tempvar caller_address = evm.message.caller;
let (output_len, output, gas_used, revert_code) = Precompiles.exec_precompile(
evm.message.code_address.evm,
evm.message.calldata_len,
evm.message.calldata,
caller_code_address,
caller_address,
evm.message.address.evm,
);
Expand All @@ -101,9 +92,15 @@ namespace Interpreter {
}
let range_check_ptr = [ap - 2];
let evm = cast([ap - 1], model.EVM*);

// Only count the cairo precompile if it was executed and did not revert.
// If it did revert, we're ensured no state changes were made in the cairo call.
let is_cairo_precompile_called = PrecompilesHelpers.is_kakarot_precompile(
evm.message.code_address.evm
);
let is_cairo_precompile_executed = is_cairo_precompile_called * (
1 - precompile_reverted
);
tempvar message = new model.Message(
bytecode=evm.message.bytecode,
bytecode_len=evm.message.bytecode_len,
Expand All @@ -120,7 +117,7 @@ namespace Interpreter {
is_create=evm.message.is_create,
depth=evm.message.depth,
env=evm.message.env,
cairo_precompile_called=is_cairo_precompile_called,
cairo_precompile_called=is_cairo_precompile_executed,
);
tempvar evm = new model.EVM(
message=message,
Expand Down
4 changes: 1 addition & 3 deletions cairo_zero/kakarot/precompiles/precompiles.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ namespace Precompiles {
// @param precompile_address The precompile evm_address.
// @param input_len The length of the input array.
// @param input The input array.
// @param caller_code_address The address of the code of the contract that calls the precompile.
// @param caller_address The address of the caller of the precompile. Delegatecall rules apply.
// @param message_address The address being executed in the current message.
// @return output_len The output length.
Expand All @@ -48,7 +47,6 @@ namespace Precompiles {
precompile_address: felt,
input_len: felt,
input: felt*,
caller_code_address: felt,
caller_address: felt,
message_address: felt,
) -> (output_len: felt, output: felt*, gas_used: felt, reverted: felt) {
Expand Down Expand Up @@ -135,7 +133,7 @@ namespace Precompiles {

kakarot_precompile:
let is_call_authorized_ = PrecompilesHelpers.is_call_authorized(
precompile_address, caller_code_address, caller_address, message_address
precompile_address, caller_address, message_address
);
tempvar is_not_authorized = 1 - is_call_authorized_;
tempvar syscall_ptr = syscall_ptr;
Expand Down
10 changes: 3 additions & 7 deletions cairo_zero/kakarot/precompiles/precompiles_helpers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,21 @@ namespace PrecompilesHelpers {

// @notice Returns whether the call to the precompile is authorized.
// @dev A call is authorized if:
// a. The precompile requires a whitelist AND the CODE_ADDRESS of the caller is whitelisted
// a. The precompile requires a whitelist AND the ADDRESS of the caller is whitelisted
// b. The precompile is CAIRO_MULTICALL_PRECOMPILE and the precompile address is the same as the message address (NOT a DELEGATECALL / CALLCODE).
// @param precompile_address The address of the precompile.
// @param caller_code_address The code_address of the precompile caller.
// @param caller_address The address of the caller.
// @param message_address The address being executed in the current message.
// @return Whether the call is authorized.
func is_call_authorized{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
precompile_address: felt,
caller_code_address: felt,
caller_address: felt,
message_address: felt,
precompile_address: felt, caller_address: felt, message_address: felt
) -> felt {
alloc_locals;
let precompile_requires_whitelist = requires_whitelist(precompile_address);

// Ensure that calls to precompiles that require a whitelist are properly authorized.
if (precompile_requires_whitelist == TRUE) {
let is_whitelisted = is_caller_whitelisted(caller_code_address);
let is_whitelisted = is_caller_whitelisted(caller_address);
tempvar syscall_ptr = syscall_ptr;
tempvar pedersen_ptr = pedersen_ptr;
tempvar range_check_ptr = range_check_ptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ func test__precompiles_run{
// Given
local address;
local input_len;
local caller_code_address;
local caller_address;
local message_address;
let (local input) = alloc();
%{
ids.address = program_input["address"]
ids.input_len = len(program_input["input"])
segments.write_arg(ids.input, program_input["input"])
ids.caller_code_address = program_input.get("caller_code_address", 0)
ids.caller_address = program_input.get("caller_address", 0)
ids.message_address = program_input.get("message_address", 0)
%}
Expand All @@ -43,7 +41,6 @@ func test__precompiles_run{
precompile_address=address,
input_len=input_len,
input=input,
caller_code_address=caller_code_address,
caller_address=caller_address,
message_address=message_address,
);
Expand Down
44 changes: 21 additions & 23 deletions cairo_zero/tests/src/kakarot/precompiles/test_precompiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
)
from tests.utils.syscall_handler import SyscallHandler

AUTHORIZED_CALLER_CODE = 0xA7071ED
UNAUTHORIZED_CALLER_CODE = 0xC0C0C0
AUTHORIZED_CALLER_ADDRESS = 0xA7071ED
UNAUTHORIZED_CALLER_ADDRESS = 0xC0C0C0
CALLER_ADDRESS = 0x123ABC432


Expand Down Expand Up @@ -80,7 +80,7 @@ def test__p256_verify_precompile(
class TestKakarotPrecompiles:
@SyscallHandler.patch(
"Kakarot_authorized_cairo_precompiles_callers",
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
1,
)
@SyscallHandler.patch("deploy", lambda *_: [0])
Expand All @@ -103,8 +103,7 @@ def test_should_deploy_account_when_sender_starknet_address_zero(
+ f"{0x60:064x}" # data_offset
+ f"{0x00:064x}" # data_len
),
caller_code_address=AUTHORIZED_CALLER_CODE,
caller_address=CALLER_ADDRESS,
caller_address=AUTHORIZED_CALLER_ADDRESS,
message_address=0x75001,
)
assert not bool(reverted)
Expand All @@ -116,25 +115,26 @@ def test_should_deploy_account_when_sender_starknet_address_zero(

@SyscallHandler.patch(
"Kakarot_authorized_cairo_precompiles_callers",
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
1,
)
@SyscallHandler.patch(
"Kakarot_evm_to_starknet_address", CALLER_ADDRESS, 0x1234
)
@SyscallHandler.patch_deploy(lambda class_hash, data: [0])
@pytest.mark.parametrize(
"address, caller_code_address, input_data, expected_return_data, expected_reverted",
"address, caller_address, input_data, expected_return_data, expected_reverted",
[
(
0x75001,
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
bytes.fromhex("0abcdef0"),
b"Kakarot: OutOfBoundsRead",
True,
), # invalid input
(
0x75001,
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
bytes.fromhex(
f"{0xc0de:064x}"
+ f"{get_selector_from_name('inc'):064x}"
Expand All @@ -146,7 +146,7 @@ def test_should_deploy_account_when_sender_starknet_address_zero(
), # call_contract
(
0x75001,
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
bytes.fromhex(
f"{0xc0de:064x}"
+ f"{get_selector_from_name('get'):064x}"
Expand All @@ -159,7 +159,7 @@ def test_should_deploy_account_when_sender_starknet_address_zero(
), # call_contract with return data
(
0x75001,
UNAUTHORIZED_CALLER_CODE,
UNAUTHORIZED_CALLER_ADDRESS,
bytes.fromhex("0abcdef0"),
b"Kakarot: unauthorizedPrecompile",
True,
Expand All @@ -176,7 +176,7 @@ def test__cairo_precompiles(
self,
cairo_run,
address,
caller_code_address,
caller_address,
input_data,
expected_return_data,
expected_reverted,
Expand All @@ -197,35 +197,34 @@ def test__cairo_precompiles(
"test__precompiles_run",
address=address,
input=input_data,
caller_code_address=caller_code_address,
caller_address=CALLER_ADDRESS,
caller_address=caller_address,
message_address=address,
)
assert bool(reverted) == expected_reverted
assert bytes(return_data) == expected_return_data
assert gas_used == (
CAIRO_PRECOMPILE_GAS
if caller_code_address == AUTHORIZED_CALLER_CODE
if caller_address == AUTHORIZED_CALLER_ADDRESS
else 0
)
return

class TestKakarotMessaging:
@SyscallHandler.patch(
"Kakarot_authorized_cairo_precompiles_callers",
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
1,
)
@SyscallHandler.patch(
"Kakarot_l1_messaging_contract_address",
0xC0DE,
)
@pytest.mark.parametrize(
"address, caller_code_address, input_data, to_address, expected_reverted_return_data, expected_reverted",
"address, caller_address, input_data, to_address, expected_reverted_return_data, expected_reverted",
[
(
0x75002,
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
encode(
["uint160", "bytes"], [0xC0DE, encode(["uint128"], [0x2A])]
),
Expand All @@ -235,15 +234,15 @@ class TestKakarotMessaging:
),
(
0x75002,
AUTHORIZED_CALLER_CODE,
AUTHORIZED_CALLER_ADDRESS,
encode(["uint160", "bytes"], [0xC0DE, 0x2A.to_bytes(1, "big")]),
0xC0DE,
b"",
False,
),
(
0x75002,
UNAUTHORIZED_CALLER_CODE,
UNAUTHORIZED_CALLER_ADDRESS,
bytes.fromhex("0abcdef0"),
0xC0DE,
b"Kakarot: unauthorizedPrecompile",
Expand All @@ -258,7 +257,7 @@ class TestKakarotMessaging:
)
def test__cairo_message(
self,
caller_code_address,
caller_address,
cairo_run,
address,
input_data,
Expand All @@ -271,8 +270,7 @@ def test__cairo_message(
"test__precompiles_run",
address=address,
input=input_data,
caller_code_address=caller_code_address,
caller_address=CALLER_ADDRESS,
caller_address=caller_address,
message_address=address,
)
if expected_reverted:
Expand Down
19 changes: 9 additions & 10 deletions solidity_contracts/src/CairoPrecompiles/DualVmToken.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ pragma solidity 0.8.27;

import {WhitelistedCallCairoLib} from "./WhitelistedCallCairoLib.sol";
import {CairoLib} from "kakarot-lib/CairoLib.sol";
import {NoDelegateCall} from "../Security/NoDelegateCall.sol";

/// @notice EVM adapter into a Cairo ERC20 token
/// @dev This implementation is highly experimental
/// It relies on CairoLib to perform Cairo precompile calls
/// Events are emitted in this contract but also in the Starknet token contract
/// @dev External functions are noDelegateCall to prevent a user making an EVM call to a malicious contract,
/// @dev External functions are to prevent a user making an EVM call to a malicious contract,
/// with any calldata, that would be able to directly control on their behalf any quantity of any one of the ERC20
/// tokens held by the victim's account contract, with the sole condition that the ERC20 has an
/// authorized DualVmToken wrapper.
Expand Down Expand Up @@ -209,7 +208,7 @@ contract DualVmToken is NoDelegateCall {
//////////////////////////////////////////////////////////////*/

/// @dev Approve an evm account spender for a specific amount
function approve(address spender, uint256 amount) external noDelegateCall returns (bool) {
function approve(address spender, uint256 amount) external returns (bool) {
uint256[] memory spenderAddressCalldata = new uint256[](1);
spenderAddressCalldata[0] = uint256(uint160(spender));
uint256 spenderStarknetAddress =
Expand All @@ -225,7 +224,7 @@ contract DualVmToken is NoDelegateCall {
/// @param spender The starknet address to approve
/// @param amount The amount of tokens to approve
/// @return True if the approval was successful
function approve(uint256 spender, uint256 amount) external noDelegateCall returns (bool) {
function approve(uint256 spender, uint256 amount) external returns (bool) {
_approve(spender, amount);
emit Approval(msg.sender, spender, amount);
return true;
Expand All @@ -250,7 +249,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The evm address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transfer(address to, uint256 amount) external noDelegateCall returns (bool) {
function transfer(address to, uint256 amount) external returns (bool) {
uint256[] memory toAddressCalldata = new uint256[](1);
toAddressCalldata[0] = uint256(uint160(to));
uint256 toStarknetAddress =
Expand All @@ -265,7 +264,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The starknet address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transfer(uint256 to, uint256 amount) external noDelegateCall returns (bool) {
function transfer(uint256 to, uint256 amount) external returns (bool) {
_transfer(to, amount);
emit Transfer(msg.sender, to, amount);
return true;
Expand All @@ -292,7 +291,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The evm address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transferFrom(address from, address to, uint256 amount) external noDelegateCall returns (bool) {
function transferFrom(address from, address to, uint256 amount) external returns (bool) {
uint256[] memory toAddressCalldata = new uint256[](1);
toAddressCalldata[0] = uint256(uint160(to));
uint256 toStarknetAddress =
Expand All @@ -314,7 +313,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The evm address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transferFrom(uint256 from, address to, uint256 amount) external noDelegateCall returns (bool) {
function transferFrom(uint256 from, address to, uint256 amount) external returns (bool) {
uint256[] memory toAddressCalldata = new uint256[](1);
toAddressCalldata[0] = uint256(uint160(to));
uint256 toStarknetAddress =
Expand All @@ -331,7 +330,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The starknet address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transferFrom(address from, uint256 to, uint256 amount) external noDelegateCall returns (bool) {
function transferFrom(address from, uint256 to, uint256 amount) external returns (bool) {
uint256[] memory fromAddressCalldata = new uint256[](1);
fromAddressCalldata[0] = uint256(uint160(from));
uint256 fromStarknetAddress =
Expand All @@ -348,7 +347,7 @@ contract DualVmToken is NoDelegateCall {
/// @param to The starknet address to transfer the tokens to
/// @param amount The amount of tokens to transfer
/// @return True if the transfer was successful
function transferFrom(uint256 from, uint256 to, uint256 amount) external noDelegateCall returns (bool) {
function transferFrom(uint256 from, uint256 to, uint256 amount) external returns (bool) {
_transferFrom(from, to, amount);
emit Transfer(from, to, amount);
return true;
Expand Down
2 changes: 1 addition & 1 deletion solidity_contracts/src/Security/DualVmTokenHack.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ contract DualVmTokenHack {
}

function tryApproveEvm() external returns (bool success) {
(success,) = target.delegatecall(
(success,) = target.delegatecall{gas: 30000}(
abi.encodeWithSelector(bytes4(keccak256("approve(address,uint256)")), address(this), AMOUNT)
);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/end_to_end/Security/test_dual_vm_token_hack.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TestActions:
async def test_malicious_approve_address_should_fail_nodelegatecaltest_malicious_approve_address_should_fail_nodelegatecall(
self, dual_vm_token, hack_vm_token, owner
):
result = await hack_vm_token.functions["tryApproveEvm()"]()
result = await hack_vm_token.functions["tryApproveEvm()"](gas_limit=1000000)
call_succeeded = int.from_bytes(bytes(result["response"]), "big")
assert call_succeeded == 0

Expand Down
Loading

0 comments on commit f33f644

Please sign in to comment.