From 9d982609d098e3701eec31f6d1cfe9b73e2e6f90 Mon Sep 17 00:00:00 2001 From: Harikrishnan Shaji Date: Wed, 12 Jun 2024 18:35:21 +0530 Subject: [PATCH] Implement SetAdd hint (#402) * Implement SetAdd * Fix incorrect method usage * Fix incorrect type used in scope * Add more tests * Fix names * More efficient resolution of operand * Temporarily comment SetAdd integration test * Fix tests --- integration_tests/cairo_files/set_add.cairo | 35 ++++ pkg/hintrunner/zero/hintcode.go | 1 + pkg/hintrunner/zero/zerohint.go | 2 + pkg/hintrunner/zero/zerohint_others.go | 122 ++++++++++++++ pkg/hintrunner/zero/zerohint_others_test.go | 168 +++++++++++++++++++- 5 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 integration_tests/cairo_files/set_add.cairo diff --git a/integration_tests/cairo_files/set_add.cairo b/integration_tests/cairo_files/set_add.cairo new file mode 100644 index 000000000..4a2d3682f --- /dev/null +++ b/integration_tests/cairo_files/set_add.cairo @@ -0,0 +1,35 @@ +// %builtins range_check +// +// from starkware.cairo.common.alloc import alloc +// from starkware.cairo.common.set import set_add +// +// struct MyStruct { +// a: felt, +// b: felt, +// } +// +// func main{range_check_ptr}() { +// alloc_locals; +// +// // An array containing two structs. +// let (local my_list: MyStruct*) = alloc(); +// assert my_list[0] = MyStruct(a=1, b=3); +// assert my_list[1] = MyStruct(a=5, b=7); +// +// // Suppose that we want to add the element +// // MyStruct(a=2, b=3) to my_list, but only if it is not already +// // present (for the purpose of the example the contents of the +// // array are known, but this doesn't have to be the case) +// let list_end: felt* = &my_list[2]; +// let (new_elm: MyStruct*) = alloc(); +// assert new_elm[0] = MyStruct(a=2, b=3); +// +// set_add{set_end_ptr=list_end}(set_ptr=my_list, elm_size=MyStruct.SIZE, elm_ptr=new_elm); +// assert my_list[2] = MyStruct(a=2, b=3); +// return (); +// } +// + +func main() { + return (); +} diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 5b35642eb..bacf114db 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -142,4 +142,5 @@ const ( memsetEnterScopeCode string = "vm_enter_scope({'n': ids.n})" vmEnterScopeCode string = "vm_enter_scope()" vmExitScopeCode string = "vm_exit_scope()" + setAddCode string = "assert ids.elm_size > 0\nassert ids.set_ptr <= ids.set_end_ptr\nelm_list = memory.get_range(ids.elm_ptr, ids.elm_size)\nfor i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):\n if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:\n ids.index = i // ids.elm_size\n ids.is_elm_in_set = 1\n break\nelse:\n ids.is_elm_in_set = 0" ) diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index ae20ce92e..9f093be43 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -204,6 +204,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createMemEnterScopeHinter(resolver, true) case vmExitScopeCode: return createVMExitScopeHinter() + case setAddCode: + return createSetAddHinter(resolver) case testAssignCode: return createTestAssignHinter(resolver) default: diff --git a/pkg/hintrunner/zero/zerohint_others.go b/pkg/hintrunner/zero/zerohint_others.go index b4ad2e36d..d3c94e616 100644 --- a/pkg/hintrunner/zero/zerohint_others.go +++ b/pkg/hintrunner/zero/zerohint_others.go @@ -2,6 +2,7 @@ package zero import ( "fmt" + "reflect" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" @@ -150,3 +151,124 @@ func createMemEnterScopeHinter(resolver hintReferenceResolver, memset bool) (hin } return newMemEnterScopeHint(value, memset), nil } + +func newSetAddHint(elmSize, elmPtr, setPtr, setEndPtr, index, isElmInSet hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "SetAdd", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> assert ids.elm_size > 0 + //> assert ids.set_ptr <= ids.set_end_ptr + //> elm_list = memory.get_range(ids.elm_ptr, ids.elm_size) + //> for i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size): + //> if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list: + //> ids.index = i // ids.elm_size + //> ids.is_elm_in_set = 1 + //> break + //> else: + //> ids.is_elm_in_set = 0 + + elmSize, err := hinter.ResolveAsUint64(vm, elmSize) + if err != nil { + return err + } + elmPtr, err := hinter.ResolveAsAddress(vm, elmPtr) + if err != nil { + return err + } + setPtr, err := hinter.ResolveAsAddress(vm, setPtr) + if err != nil { + return err + } + setEndPtr, err := hinter.ResolveAsAddress(vm, setEndPtr) + if err != nil { + return err + } + indexAddr, err := index.GetAddress(vm) + if err != nil { + return err + } + isElmInSetAddr, err := isElmInSet.GetAddress(vm) + if err != nil { + return err + } + + //> assert ids.elm_size > 0 + if elmSize == 0 { + return fmt.Errorf("assert ids.elm_size > 0 failed") + } + + //> assert ids.set_ptr <= ids.set_end_ptr + if setPtr.Offset > setEndPtr.Offset { + return fmt.Errorf("assert ids.set_ptr <= ids.set_end_ptr failed") + } + + //> elm_list = memory.get_range(ids.elm_ptr, ids.elm_size) + elmList, err := vm.Memory.GetConsecutiveMemoryValues(*elmPtr, int16(elmSize)) + if err != nil { + return err + } + + //> for i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size): + //> if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list: + //> ids.index = i // ids.elm_size + //> ids.is_elm_in_set = 1 + //> break + //> else: + //> ids.is_elm_in_set = 0 + isElmInSetFelt := utils.FeltZero + totalSetLength := setEndPtr.Offset - setPtr.Offset + for i := uint64(0); i < totalSetLength; i += elmSize { + memoryElmList, err := vm.Memory.GetConsecutiveMemoryValues(*setPtr, int16(elmSize)) + if err != nil { + return err + } + *setPtr, err = setPtr.AddOffset(int16(elmSize)) + if err != nil { + return err + } + if reflect.DeepEqual(memoryElmList, elmList) { + indexFelt := fp.NewElement(i / elmSize) + indexMv := memory.MemoryValueFromFieldElement(&indexFelt) + err := vm.Memory.WriteToAddress(&indexAddr, &indexMv) + if err != nil { + return err + } + isElmInSetFelt = utils.FeltOne + break + } + } + + mv := memory.MemoryValueFromFieldElement(&isElmInSetFelt) + return vm.Memory.WriteToAddress(&isElmInSetAddr, &mv) + }, + } +} + +func createSetAddHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + elmSize, err := resolver.GetResOperander("elm_size") + if err != nil { + return nil, err + } + elmPtr, err := resolver.GetResOperander("elm_ptr") + if err != nil { + return nil, err + } + setPtr, err := resolver.GetResOperander("set_ptr") + if err != nil { + return nil, err + } + setEndPtr, err := resolver.GetResOperander("set_end_ptr") + if err != nil { + return nil, err + } + index, err := resolver.GetResOperander("index") + if err != nil { + return nil, err + } + isElmInSet, err := resolver.GetResOperander("is_elm_in_set") + if err != nil { + return nil, err + } + + return newSetAddHint(elmSize, elmPtr, setPtr, setEndPtr, index, isElmInSet), nil +} diff --git a/pkg/hintrunner/zero/zerohint_others_test.go b/pkg/hintrunner/zero/zerohint_others_test.go index a37c01afa..11f99f198 100644 --- a/pkg/hintrunner/zero/zerohint_others_test.go +++ b/pkg/hintrunner/zero/zerohint_others_test.go @@ -4,9 +4,10 @@ import ( "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -func TestZeroHintMemcpy(t *testing.T) { +func TestZeroHintOthers(t *testing.T) { runHinterTests(t, map[string][]hintTestCase{ "MemcpyContinueCopying": { { @@ -57,5 +58,170 @@ func TestZeroHintMemcpy(t *testing.T) { check: varValueInScopeEquals("n", *feltUint64(1)), }, }, + "SetAdd": { + { + operanders: []*hintOperander{ + {Name: "elm_size", Kind: apRelative, Value: feltUint64(0)}, + {Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)}, + {Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)}, + {Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)}, + {Name: "index", Kind: uninitialized}, + {Name: "is_elm_in_set", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSetAddHint( + ctx.operanders["elm_size"], + ctx.operanders["elm_ptr"], + ctx.operanders["set_ptr"], + ctx.operanders["set_end_ptr"], + ctx.operanders["index"], + ctx.operanders["is_elm_in_set"], + ) + }, + errCheck: errorTextContains("assert ids.elm_size > 0 failed"), + }, + { + operanders: []*hintOperander{ + {Name: "elm_size", Kind: apRelative, Value: feltUint64(1)}, + {Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)}, + {Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 1)}, + {Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)}, + {Name: "index", Kind: uninitialized}, + {Name: "is_elm_in_set", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSetAddHint( + ctx.operanders["elm_size"], + ctx.operanders["elm_ptr"], + ctx.operanders["set_ptr"], + ctx.operanders["set_end_ptr"], + ctx.operanders["index"], + ctx.operanders["is_elm_in_set"], + ) + }, + errCheck: errorTextContains("assert ids.set_ptr <= ids.set_end_ptr failed"), + }, + { + operanders: []*hintOperander{ + {Name: "elm.1", Kind: apRelative, Value: feltUint64(1)}, + {Name: "elm.2", Kind: apRelative, Value: feltUint64(2)}, + {Name: "elm.3", Kind: apRelative, Value: feltUint64(3)}, + {Name: "elm.4", Kind: apRelative, Value: feltUint64(4)}, + {Name: "set.1", Kind: apRelative, Value: feltUint64(5)}, + {Name: "set.2", Kind: apRelative, Value: feltUint64(6)}, + {Name: "set.3", Kind: apRelative, Value: feltUint64(7)}, + {Name: "set.4", Kind: apRelative, Value: feltUint64(8)}, + {Name: "set.5", Kind: apRelative, Value: feltUint64(9)}, + {Name: "set.6", Kind: apRelative, Value: feltUint64(10)}, + {Name: "set.7", Kind: apRelative, Value: feltUint64(11)}, + {Name: "set.8", Kind: apRelative, Value: feltUint64(12)}, + {Name: "set.9", Kind: apRelative, Value: feltUint64(1)}, + {Name: "set.10", Kind: apRelative, Value: feltUint64(2)}, + {Name: "set.11", Kind: apRelative, Value: feltUint64(3)}, + {Name: "set.12", Kind: apRelative, Value: feltUint64(4)}, + {Name: "set.13", Kind: apRelative, Value: feltUint64(13)}, + {Name: "set.14", Kind: apRelative, Value: feltUint64(14)}, + {Name: "set.15", Kind: apRelative, Value: feltUint64(15)}, + {Name: "set.16", Kind: apRelative, Value: feltUint64(16)}, + {Name: "elm_size", Kind: apRelative, Value: feltUint64(4)}, + {Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)}, + {Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 8)}, + {Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 24)}, + {Name: "index", Kind: uninitialized}, + {Name: "is_elm_in_set", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSetAddHint( + ctx.operanders["elm_size"], + ctx.operanders["elm_ptr"], + ctx.operanders["set_ptr"], + ctx.operanders["set_end_ptr"], + ctx.operanders["index"], + ctx.operanders["is_elm_in_set"], + ) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "index": feltUint64(2), + "is_elm_in_set": feltUint64(1), + }), + }, + { + operanders: []*hintOperander{ + {Name: "elm.1", Kind: apRelative, Value: feltUint64(1)}, + {Name: "elm.2", Kind: apRelative, Value: feltUint64(2)}, + {Name: "elm.3", Kind: apRelative, Value: feltUint64(3)}, + {Name: "elm.4", Kind: apRelative, Value: feltUint64(4)}, + {Name: "set.1", Kind: apRelative, Value: feltUint64(5)}, + {Name: "set.2", Kind: apRelative, Value: feltUint64(6)}, + {Name: "set.3", Kind: apRelative, Value: feltUint64(7)}, + {Name: "set.4", Kind: apRelative, Value: feltUint64(8)}, + {Name: "set.5", Kind: apRelative, Value: feltUint64(9)}, + {Name: "set.6", Kind: apRelative, Value: feltUint64(10)}, + {Name: "set.7", Kind: apRelative, Value: feltUint64(11)}, + {Name: "set.8", Kind: apRelative, Value: feltUint64(12)}, + {Name: "set.9", Kind: apRelative, Value: feltUint64(13)}, + {Name: "set.10", Kind: apRelative, Value: feltUint64(14)}, + {Name: "set.11", Kind: apRelative, Value: feltUint64(15)}, + {Name: "set.12", Kind: apRelative, Value: feltUint64(16)}, + {Name: "set.13", Kind: apRelative, Value: feltUint64(17)}, + {Name: "set.14", Kind: apRelative, Value: feltUint64(18)}, + {Name: "set.15", Kind: apRelative, Value: feltUint64(19)}, + {Name: "set.16", Kind: apRelative, Value: feltUint64(20)}, + {Name: "elm_size", Kind: apRelative, Value: feltUint64(4)}, + {Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)}, + {Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 8)}, + {Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 24)}, + {Name: "index", Kind: uninitialized}, + {Name: "is_elm_in_set", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSetAddHint( + ctx.operanders["elm_size"], + ctx.operanders["elm_ptr"], + ctx.operanders["set_ptr"], + ctx.operanders["set_end_ptr"], + ctx.operanders["index"], + ctx.operanders["is_elm_in_set"], + ) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "is_elm_in_set": feltUint64(0), + }), + }, + { + operanders: []*hintOperander{ + {Name: "elm.1", Kind: apRelative, Value: feltUint64(1)}, + {Name: "elm.2", Kind: apRelative, Value: feltUint64(2)}, + {Name: "elm.3", Kind: apRelative, Value: feltUint64(3)}, + {Name: "elm.4", Kind: apRelative, Value: feltUint64(4)}, + {Name: "elm.5", Kind: apRelative, Value: feltUint64(5)}, + {Name: "set.1", Kind: apRelative, Value: feltUint64(1)}, + {Name: "set.2", Kind: apRelative, Value: feltUint64(2)}, + {Name: "set.3", Kind: apRelative, Value: feltUint64(3)}, + {Name: "set.4", Kind: apRelative, Value: feltUint64(4)}, + {Name: "set.5", Kind: apRelative, Value: feltUint64(5)}, + {Name: "elm_size", Kind: apRelative, Value: feltUint64(5)}, + {Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)}, + {Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 9)}, + {Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 14)}, + {Name: "index", Kind: uninitialized}, + {Name: "is_elm_in_set", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSetAddHint( + ctx.operanders["elm_size"], + ctx.operanders["elm_ptr"], + ctx.operanders["set_ptr"], + ctx.operanders["set_end_ptr"], + ctx.operanders["index"], + ctx.operanders["is_elm_in_set"], + ) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "index": feltUint64(0), + "is_elm_in_set": feltUint64(1), + }), + }, + }, }) }