Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionalities for main function inputs #681

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d7b74e6
Fixes for the generation of entry code, fixes of hints parsing
MaksymMalicki Nov 15, 2024
dfd9ec8
Add modifications to the runner
MaksymMalicki Nov 21, 2024
a7814a1
Merge branch 'main' into add_code_entry
MaksymMalicki Dec 2, 2024
83872d8
Add fixes for the entrycode generation
MaksymMalicki Dec 3, 2024
cd04099
Refactor main CLI, offset the hints indexes by entry code size, load …
MaksymMalicki Dec 10, 2024
5eb27ab
Add available gas and user args (#677)
MaksymMalicki Dec 25, 2024
8812042
Fix the integration tests
MaksymMalicki Dec 26, 2024
78cf788
Fixes in the runner
MaksymMalicki Dec 27, 2024
af47a3d
Fixes in the runner
MaksymMalicki Dec 28, 2024
73f6dd5
Fix the unit tests, uncomment pythonVm execution in integration tests…
MaksymMalicki Dec 28, 2024
17b17df
Add writing tokens gas cost to memory
MaksymMalicki Jan 7, 2025
688b3d2
Proper builtins initialization for cairo mode
MaksymMalicki Jan 8, 2025
f105506
Address comments in the PR
MaksymMalicki Jan 9, 2025
cc82938
Fix bugs regarding dicts
MaksymMalicki Jan 10, 2025
a32add7
Remove prints
MaksymMalicki Jan 10, 2025
07f4016
Merge branch 'main' into missing_dict_functionalities
MaksymMalicki Jan 10, 2025
53d4190
Fixes of the last tests for the dicts
MaksymMalicki Jan 11, 2025
31aeeed
Add dict_non_squashed dir to the integration tests
MaksymMalicki Jan 11, 2025
086011a
Add checks for the matching args size, rename files, modify the integ…
MaksymMalicki Jan 11, 2025
fc09b66
Almost all pass
MaksymMalicki Jan 13, 2025
72b3c14
Fix lint and unit tests
MaksymMalicki Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet"
zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/runner"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"github.com/urfave/cli/v2"
)

Expand Down Expand Up @@ -217,26 +216,18 @@ func main() {
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}
program, hints, err := runner.AssembleProgram(cairoProgram)
userArgs, err := starknet.ParseCairoProgramArgs(args)
if err != nil {
return fmt.Errorf("cannot parse args: %w", err)
}
program, hints, userArgs, err := runner.AssembleProgram(cairoProgram, userArgs, availableGas)
if err != nil {
return fmt.Errorf("cannot assemble program: %w", err)
}
runnerMode := runner.ExecutionModeCairo
if proofmode {
runnerMode = runner.ProofModeCairo
}
userArgs, err := starknet.ParseCairoProgramArgs(args)
if err != nil {
return fmt.Errorf("cannot parse args: %w", err)
}
if availableGas > 0 {
// The first argument is the available gas
availableGasArg := starknet.CairoFuncArgs{
Single: new(fp.Element).SetUint64(availableGas),
Array: nil,
}
userArgs = append([]starknet.CairoFuncArgs{availableGasArg}, userArgs...)
}
return runVM(program, proofmode, maxsteps, entrypointOffset, collectTrace, traceLocation, buildMemory, memoryLocation, layoutName, airPublicInputLocation, airPrivateInputLocation, hints, runnerMode, userArgs)
},
},
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/.env
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Set to run some specific file tests (ex. fib.cairo,alloc.cairo)
INTEGRATION_TESTS_FILTERS=
INTEGRATION_TESTS_FILTERS=poseidon_pedersen__starknet.cairo
51 changes: 32 additions & 19 deletions integration_tests/cairo_vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (f *Filter) filtered(testFile string) bool {
return false
}

func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool, zero bool) {
func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool, zero bool, inputArgs string) {
t.Logf("testing: %s\n", path)
compiledOutput, err := compileCairoCode(path, zero)
if err != nil {
Expand All @@ -75,7 +75,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
}
layout := getLayoutFromFileName(path)

elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput, layout, zero)
elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput, layout, zero, inputArgs)
if errorExpected {
assert.Error(t, err, path)
writeToFile(path)
Expand All @@ -92,7 +92,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
if zero {
rustVmFilePath = compiledOutput
}
elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(rustVmFilePath, layout, zero)
elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(rustVmFilePath, layout, zero, inputArgs)
if errorExpected {
// we let the code go on so that we can check if the go vm also raises an error
assert.Error(t, err, path)
Expand Down Expand Up @@ -171,15 +171,27 @@ func TestCairoFiles(t *testing.T) {
if err != nil {
panic(fmt.Errorf("failed to open file: %w", err))
}

file.Close()
roots := []struct {
type TestCase struct {
path string
zero bool
}{
{"./cairo_zero_hint_tests/", true},
{"./cairo_zero_file_tests/", true},
{"./builtin_tests/", true},
// {"./cairo_1_programs/", false},
}
roots := []TestCase{
// {"./cairo_zero_hint_tests/", true},
// {"./cairo_zero_file_tests/", true},
// {"./builtin_tests/", true},
{"./cairo_1_programs/", false},
{"./cairo_1_programs/dict_non_squashed", false},
{"./cairo_1_programs/with_input", false},
}

inputArgsMap := map[string]string{
"cairo_1_programs/with_input/array_input_sum__small.cairo": "2 [111 222 333] 1 [444 555 666 777]",
"cairo_1_programs/with_input/array_length__small.cairo": "[1 2 3 4 5 6] [7 8 9 10]",
"cairo_1_programs/with_input/branching.cairo": "123",
"cairo_1_programs/with_input/dict_with_input__small.cairo": "[1 2 3 4]",
"cairo_1_programs/with_input/tensor__small.cairo": "[1 4] [1 5]",
}

// filter is for debugging purposes
Expand Down Expand Up @@ -210,22 +222,19 @@ func TestCairoFiles(t *testing.T) {
if !filter.filtered(name) {
continue
}

inputArgs := inputArgsMap[path]
// we run tests concurrently if we don't need benchmarks
if !*zerobench {
sem <- struct{}{} // acquire a semaphore slot
wg.Add(1)

go func(path, name string, root struct {
path string
zero bool
}) {
go func(path, name string, root TestCase, inputArgs string) {
defer wg.Done()
defer func() { <-sem }() // release the semaphore slot when done
runAndTestFile(t, path, name, benchmarkMap, *zerobench, errorExpected, root.zero)
}(path, name, root)
runAndTestFile(t, path, name, benchmarkMap, *zerobench, errorExpected, root.zero, inputArgs)
}(path, name, root, inputArgs)
} else {
runAndTestFile(t, path, name, benchmarkMap, *zerobench, errorExpected, root.zero)
runAndTestFile(t, path, name, benchmarkMap, *zerobench, errorExpected, root.zero, inputArgs)
}
}
}
Expand Down Expand Up @@ -399,7 +408,7 @@ func runPythonVm(path, layout string) (time.Duration, string, string, error) {

// given a path to a compiled cairo zero file, execute it using the
// rust vm and return the trace and memory files location
func runRustVm(path, layout string, zero bool) (time.Duration, string, string, error) {
func runRustVm(path, layout string, zero bool, inputArgs string) (time.Duration, string, string, error) {
traceOutput := swapExtenstion(path, rsTraceSuffix)
memoryOutput := swapExtenstion(path, rsMemorySuffix)

Expand All @@ -411,6 +420,8 @@ func runRustVm(path, layout string, zero bool) (time.Duration, string, string, e
memoryOutput,
"--layout",
layout,
"--args",
inputArgs,
}

if zero {
Expand Down Expand Up @@ -440,7 +451,7 @@ func runRustVm(path, layout string, zero bool) (time.Duration, string, string, e

// given a path to a compiled cairo zero file, execute
// it using our vm
func runVm(path, layout string, zero bool) (time.Duration, string, string, string, error) {
func runVm(path, layout string, zero bool, inputArgs string) (time.Duration, string, string, string, error) {
traceOutput := swapExtenstion(path, traceSuffix)
memoryOutput := swapExtenstion(path, memorySuffix)

Expand Down Expand Up @@ -470,6 +481,8 @@ func runVm(path, layout string, zero bool) (time.Duration, string, string, strin
layout,
"--available_gas",
"9999999",
"--args",
inputArgs,
}
}
args = append(args, path)
Expand Down
56 changes: 42 additions & 14 deletions pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,8 @@ func (hint *Felt252DictEntryInit) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin

prevValue, err := ctx.DictionaryManager.At(dictPtr, key)
if err != nil {
return fmt.Errorf("get dictionary entry: %w", err)
mv := mem.MemoryValueFromFieldElement(&utils.FeltZero)
prevValue = &mv
}
if prevValue == nil {
mv := mem.EmptyMemoryValueAsFelt()
Expand Down Expand Up @@ -1237,7 +1238,8 @@ func (hint *InitSquashData) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunne
// todo(rodro): Don't know if it could be called multiple times, or
err := hinter.InitializeSquashedDictionaryManager(ctx)
if err != nil {
return err
ctx.SquashedDictionaryManager = hinter.SquashedDictionaryManager{}
_ = hinter.InitializeSquashedDictionaryManager(ctx)
}

dictAccessPtr, err := hinter.ResolveAsAddress(vm, hint.DictAccesses)
Expand Down Expand Up @@ -1272,7 +1274,7 @@ func (hint *InitSquashData) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunne

// sort the keys in descending order
sort.Slice(ctx.SquashedDictionaryManager.Keys, func(i, j int) bool {
return ctx.SquashedDictionaryManager.Keys[i].Cmp(&ctx.SquashedDictionaryManager.Keys[j]) < 0
return ctx.SquashedDictionaryManager.Keys[i].Cmp(&ctx.SquashedDictionaryManager.Keys[j]) > 0
})

// if the first key is bigger than 2^128, signal it
Expand Down Expand Up @@ -1401,11 +1403,15 @@ func (hint *ShouldContinueSquashLoop) Execute(vm *VM.VirtualMachine, ctx *hinter
}

var shouldContinueLoop f.Element
if lastIndices, err := ctx.SquashedDictionaryManager.LastIndices(); err == nil && len(lastIndices) <= 1 {
shouldContinueLoop.SetOne()
} else if err != nil {
lastIndices, err := ctx.SquashedDictionaryManager.LastIndices()
if err != nil {
return fmt.Errorf("get last indices: %w", err)
}
if len(lastIndices) > 1 {
shouldContinueLoop.SetOne()
} else {
shouldContinueLoop.SetZero()
}

mv := mem.MemoryValueFromFieldElement(&shouldContinueLoop)
return vm.Memory.WriteToAddress(&shouldContinuePtr, &mv)
Expand All @@ -1425,11 +1431,15 @@ func (hint *GetNextDictKey) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunne
return fmt.Errorf("get next key address: %w", err)
}

nextKey, err := ctx.SquashedDictionaryManager.PopKey()
_, err = ctx.SquashedDictionaryManager.PopKey()
if err != nil {
return fmt.Errorf("pop key: %w", err)
}

nextKey, err := ctx.SquashedDictionaryManager.LastKey()
if err != nil {
return fmt.Errorf("get last key: %w", err)
}
mv := mem.MemoryValueFromFieldElement(&nextKey)
return vm.Memory.WriteToAddress(&nextKeyAddr, &mv)
}
Expand Down Expand Up @@ -1928,29 +1938,47 @@ func (hint *ExternalWriteArgsToMemory) String() string {
}

func (hint *ExternalWriteArgsToMemory) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
userArgsVar, err := ctx.ScopeManager.GetVariableValue("userArgs")
userArgs, err := hinter.GetVariableAs[[]starknet.CairoFuncArgs](&ctx.ScopeManager, "userArgs")
if err != nil {
return fmt.Errorf("get user args: %v", err)
}
userArgs, ok := userArgsVar.([]starknet.CairoFuncArgs)
if !ok {
return fmt.Errorf("expected user args to be a list of CairoFuncArgs")
apOffset, err := hinter.GetVariableAs[uint64](&ctx.ScopeManager, "apOffset")
if err != nil {
return fmt.Errorf("get ap offset: %v", err)
}
fmt.Println("apOffset", apOffset, "ap.offset", vm.Context.Ap)
apOffset += vm.Context.Ap
for _, arg := range userArgs {
if arg.Single != nil {
mv := mem.MemoryValueFromFieldElement(arg.Single)
err := vm.Memory.Write(1, vm.Context.Ap, &mv)
err := vm.Memory.Write(1, apOffset, &mv)
if err != nil {
return fmt.Errorf("write single arg: %v", err)
}
apOffset++
} else if arg.Array != nil {
arrayBase := vm.Memory.AllocateEmptySegment()
mv := mem.MemoryValueFromMemoryAddress(&arrayBase)
err := vm.Memory.Write(1, vm.Context.Ap, &mv)
err := vm.Memory.Write(1, apOffset, &mv)
if err != nil {
return fmt.Errorf("write array base: %v", err)
}
// TODO: Implement array writing
apOffset++
arrayEnd := arrayBase
for _, val := range arg.Array {
arrayEnd.Offset += 1
mv := mem.MemoryValueFromFieldElement(&val)
err := vm.Memory.Write(arrayEnd.SegmentIndex, arrayEnd.Offset, &mv)
if err != nil {
return fmt.Errorf("write array element: %v", err)
}
}
mv = mem.MemoryValueFromMemoryAddress(&arrayEnd)
err = vm.Memory.Write(1, apOffset, &mv)
if err != nil {
return fmt.Errorf("write array end: %v", err)
}
apOffset++
}
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions pkg/hintrunner/hintrunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ type HintRunner struct {
hints map[uint64][]h.Hinter
}

func NewHintRunner(hints map[uint64][]h.Hinter, userArgs []starknet.CairoFuncArgs) HintRunner {
func NewHintRunner(hints map[uint64][]h.Hinter, userArgs []starknet.CairoFuncArgs, writeApOffset uint64) HintRunner {
context := *h.InitializeDefaultContext()
if userArgs != nil {
err := context.ScopeManager.AssignVariable("userArgs", userArgs)
err := context.ScopeManager.AssignVariables(map[string]any{
"userArgs": userArgs,
"apOffset": writeApOffset,
})
// Error handling: this condition should never be true, since the context was initialized above
if err != nil {
panic(fmt.Errorf("assign userArgs: %v", err))
Expand Down
4 changes: 2 additions & 2 deletions pkg/hintrunner/hintrunner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestExistingHint(t *testing.T) {

hr := NewHintRunner(map[uint64][]hinter.Hinter{
10: {&allocHint},
}, nil)
}, nil, 0)

vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Expand All @@ -44,7 +44,7 @@ func TestNoHint(t *testing.T) {

hr := NewHintRunner(map[uint64][]hinter.Hinter{
10: {&allocHint},
}, nil)
}, nil, 0)

vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func gasInitialization(memory *mem.Memory) error {
return err
}

preCostTokenTypes := []TokenGasCost{PedersenToken, BitwiseToken, EcOpToken, PoseidonToken, AddModToken, MulModToken}
preCostTokenTypes := []TokenGasCost{PedersenToken, PoseidonToken, BitwiseToken, EcOpToken, AddModToken, MulModToken}

for _, token := range preCostTokenTypes {
cost, err := getTokenGasCost(token)
Expand Down
Loading
Loading