Skip to content

Commit

Permalink
Merge branch 'main' into documentation-intro-NethermindEth#328
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 authored Jun 12, 2024
2 parents 87b0acf + 1b62872 commit 1ceafca
Show file tree
Hide file tree
Showing 22 changed files with 1,408 additions and 48 deletions.
40 changes: 39 additions & 1 deletion integration_tests/cairo_files/dict.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// inspired from the dict.cairo integration test in the lambdaclass cairo-vm codebase

from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.dict import dict_read
from starkware.cairo.common.dict import dict_read, dict_write
from starkware.cairo.common.dict_access import DictAccess

func test_default_dict() {
Expand All @@ -24,9 +24,47 @@ func test_read() {
return ();
}

func test_write() {
alloc_locals;
let (local my_dict: DictAccess*) = default_dict_new(123);
let (local val1: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val1 = 123;

let (local val2: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val2 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=512);
let (local val3: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val3 = 512;

let (local val4: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val4 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=1024);
let (local val5: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val5 = 1024;

let (local val6: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val6 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=888);
dict_write{dict_ptr=my_dict}(key=2, new_value=999);
let (local val7: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val7 = 888;
let (local val8: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val8 = 999;
let (local val9: felt) = dict_read{dict_ptr=my_dict}(key=3);
assert val9 = 123;

return ();
}

func main() {
test_default_dict();
test_read();
test_write();
return ();
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// The content of this file has been partially borrowed from LambdaClass Cairo VM in Rust
// See https://github.com/lambdaclass/cairo-vm/blob/aecbb3f01dacb6d3f90256c808466c2c37606252/cairo_programs/keccak_alternative_hint.cairo#L20

%builtins output range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import (
_prepare_block,
KECCAK_FULL_RATE_IN_BYTES,
KECCAK_FULL_RATE_IN_WORDS,
KECCAK_STATE_SIZE_FELTS,
)
from starkware.cairo.common.math import assert_nn_le
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func _block_permutation_cairo_keccak{output_ptr: felt*, keccak_ptr: felt*}() {
alloc_locals;
let output = output_ptr;
let keccak_ptr_start = keccak_ptr - KECCAK_STATE_SIZE_FELTS;
%{
from starkware.cairo.common.cairo_keccak.keccak_utils import keccak_func
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
assert 0 <= _keccak_state_size_felts < 100
output_values = keccak_func(memory.get_range(
ids.keccak_ptr_start, _keccak_state_size_felts))
segments.write_arg(ids.output, output_values)
%}
let keccak_ptr = keccak_ptr + KECCAK_STATE_SIZE_FELTS;

return ();
}

func run_cairo_keccak{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (output: felt*) = alloc();
let keccak_output = output;

let (inputs: felt*) = alloc();
let inputs_start = inputs;
fill_array(inputs, 9, 3, 0);

let (state: felt*) = alloc();
let state_start = state;
fill_array(state, 5, 25, 0);

let n_bytes = 24;

_prepare_block{keccak_ptr=output_ptr}(inputs=inputs, n_bytes=n_bytes, state=state);
_block_permutation_cairo_keccak{keccak_ptr=output_ptr}();

local full_word: felt;
%{ ids.full_word = int(ids.n_bytes >= 8) %}
assert full_word = 1;

let n_bytes = 8;
local full_word: felt;
%{ ids.full_word = int(ids.n_bytes >= 8) %}
assert full_word = 1;

let n_bytes = 7;
local full_word: felt;
%{ ids.full_word = int(ids.n_bytes >= 8) %}
assert full_word = 0;

return ();
}

func fill_array(array: felt*, base: felt, array_length: felt, iterator: felt) {
if (iterator == array_length) {
return ();
}

assert array[iterator] = base;

return fill_array(array, base, array_length, iterator + 1);
}

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
run_cairo_keccak();

return ();
}
14 changes: 14 additions & 0 deletions pkg/hintrunner/hinter/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ func (sm *ScopeManager) GetVariableValueAsBigInt(name string) (*big.Int, error)
return valueBig, nil
}

func (sm *ScopeManager) GetVariableValueAsUint64(name string) (uint64, error) {
value, err := sm.GetVariableValue(name)
if err != nil {
return 0, err
}

valueUint, ok := value.(uint64)
if !ok {
return 0, fmt.Errorf("value: %s is not a uint64", value)
}

return valueUint, nil
}

func (sm *ScopeManager) getCurrentScope() (*map[string]any, error) {
if len(sm.scopes) == 0 {
return nil, fmt.Errorf("expected at least one existing scope")
Expand Down
32 changes: 17 additions & 15 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ import (

// Used to keep track of all dictionaries data
type ZeroDictionary struct {
// The data contained on a dictionary
data map[f.Element]mem.MemoryValue
// The Data contained on a dictionary
Data map[f.Element]mem.MemoryValue
// Default value for key not present in the dictionary
defaultValue mem.MemoryValue
DefaultValue mem.MemoryValue
// first free offset in memory segment of dictionary
freeOffset uint64
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) At(key f.Element) (mem.MemoryValue, error) {
if value, ok := d.data[key]; ok {
if value, ok := d.Data[key]; ok {
return value, nil
}
if d.defaultValue != mem.UnknownValue {
return d.defaultValue, nil
if d.DefaultValue != mem.UnknownValue {
return d.DefaultValue, nil
}
return mem.UnknownValue, fmt.Errorf("no value for key: %v", key)
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) Set(key f.Element, value mem.MemoryValue) {
d.data[key] = value
d.Data[key] = value
}

// Given a incrementBy value, it increments the freeOffset field of dictionary by it
func (d *ZeroDictionary) IncrementFreeOffset(freeOffset uint64) {
d.freeOffset += freeOffset
*d.FreeOffset += freeOffset
}

// Used to manage dictionaries creation
Expand All @@ -56,10 +56,11 @@ func NewZeroDictionaryManager() ZeroDictionaryManager {
// to the start of this segment
func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.MemoryAddress {
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: mem.UnknownValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: mem.UnknownValue,
FreeOffset: &freeOffset,
}
return newDictAddr
}
Expand All @@ -70,10 +71,11 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.Memory
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: defaultValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: defaultValue,
FreeOffset: &freeOffset,
}
return newDefaultDictAddr
}
Expand Down
17 changes: 15 additions & 2 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ const (

unsignedDivRemCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.div)\nassert 0 < ids.div <= PRIME // range_check_builtin.bound, \\\n f'div={hex(ids.div)} is out of the valid range.'\nids.q, ids.r = divmod(ids.value, ids.div)"

signedPowCode string = "assert ids.base != 0, 'Cannot raise 0 to a negative power.'"

// split_felt() hints.
splitFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128"

Expand Down Expand Up @@ -106,14 +108,25 @@ const (
// ------ Blake Hash hints related code ------
blake2sAddUint256BigendCode string = "B = 32\nMASK = 2 ** 32 - 1\nsegments.write_arg(ids.data, [(ids.high >> (B * (3 - i))) & MASK for i in range(4)])\nsegments.write_arg(ids.data + 4, [(ids.low >> (B * (3 - i))) & MASK for i in range(4)])"
blake2sAddUint256Code string = "B = 32\nMASK = 2 ** 32 - 1\nsegments.write_arg(ids.data, [(ids.low >> (B * i)) & MASK for i in range(4)])\nsegments.write_arg(ids.data + 4, [(ids.high >> (B * i)) & MASK for i in range(4)])"
blake2sFinalizeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.INPUT_BLOCK_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (modified_iv + message + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)"

// ------ Keccak hints related code ------

keccakWriteArgs string = "segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])\nsegments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
assert 0 <= _keccak_state_size_felts < 100
assert 0 <= _block_size < 10
inp = [0] * _keccak_state_size_felts
padding = (inp + keccak_func(inp)) * _block_size
segments.write_arg(ids.keccak_ptr_end, padding)`
keccakWriteArgsCode string = "segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])\nsegments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])"
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\noutput_values = keccak_func(memory.get_range(\nids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"

// ------ Dictionaries hints related code ------
defaultDictNewCode string = "if '__dict_manager' not in globals():\n from starkware.cairo.common.dict import DictManager\n __dict_manager = DictManager()\n\nmemory[ap] = __dict_manager.new_default_dict(segments, ids.default_value)"
dictReadCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.value = dict_tracker.data[ids.key]"
dictWriteCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.dict_ptr.prev_value = dict_tracker.data[ids.key]\ndict_tracker.data[ids.key] = ids.new_value"
squashDictInnerAssertLenKeys string = "assert len(keys) == 0"
squashDictInnerContinueLoop string = "ids.loop_temps.should_continue = 1 if current_access_indices else 0"
squashDictInnerSkipLoop string = "ids.should_skip_loop = 0 if current_access_indices else 1"
Expand Down
14 changes: 13 additions & 1 deletion pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createSignedDivRemHinter(resolver)
case powCode:
return createPowHinter(resolver)
case signedPowCode:
return createSignedPowHinter(resolver)
case splitFeltCode:
return createSplitFeltHinter(resolver)
case sqrtCode:
Expand Down Expand Up @@ -146,9 +148,17 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createBlake2sAddUint256Hinter(resolver, true)
case blake2sAddUint256Code:
return createBlake2sAddUint256Hinter(resolver, false)
case blake2sFinalizeCode:
return createBlake2sFinalizeHinter(resolver)
// Keccak hints
case keccakWriteArgs:
case keccakWriteArgsCode:
return createKeccakWriteArgsHinter(resolver)
case cairoKeccakFinalizeCode:
return createCairoKeccakFinalizeHinter(resolver)
case unsafeKeccakCode:
return createUnsafeKeccakHinter(resolver)
case blockPermutationCode:
return createBlockPermutationHinter(resolver)
// Usort hints
case usortEnterScopeCode:
return createUsortEnterScopeHinter()
Expand All @@ -165,6 +175,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createDefaultDictNewHinter(resolver)
case dictReadCode:
return createDictReadHinter(resolver)
case dictWriteCode:
return createDictWriteHinter(resolver)
case squashDictInnerAssertLenKeys:
return createSquashDictInnerAssertLenKeysHinter()
case squashDictInnerContinueLoop:
Expand Down
Loading

0 comments on commit 1ceafca

Please sign in to comment.