Skip to content

Commit

Permalink
Implement SetAdd hint (#402)
Browse files Browse the repository at this point in the history
* Implement SetAdd

* Fix incorrect method usage

* Fix incorrect type used in scope

* Add more tests

* Fix names

* More efficient resolution of operand

* Temporarily comment SetAdd integration test

* Fix tests
  • Loading branch information
har777 authored Jun 12, 2024
1 parent 373f0d2 commit 9d98260
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 1 deletion.
35 changes: 35 additions & 0 deletions integration_tests/cairo_files/set_add.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// %builtins range_check
//
// from starkware.cairo.common.alloc import alloc
// from starkware.cairo.common.set import set_add
//
// struct MyStruct {
// a: felt,
// b: felt,
// }
//
// func main{range_check_ptr}() {
// alloc_locals;
//
// // An array containing two structs.
// let (local my_list: MyStruct*) = alloc();
// assert my_list[0] = MyStruct(a=1, b=3);
// assert my_list[1] = MyStruct(a=5, b=7);
//
// // Suppose that we want to add the element
// // MyStruct(a=2, b=3) to my_list, but only if it is not already
// // present (for the purpose of the example the contents of the
// // array are known, but this doesn't have to be the case)
// let list_end: felt* = &my_list[2];
// let (new_elm: MyStruct*) = alloc();
// assert new_elm[0] = MyStruct(a=2, b=3);
//
// set_add{set_end_ptr=list_end}(set_ptr=my_list, elm_size=MyStruct.SIZE, elm_ptr=new_elm);
// assert my_list[2] = MyStruct(a=2, b=3);
// return ();
// }
//

func main() {
return ();
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,5 @@ const (
memsetEnterScopeCode string = "vm_enter_scope({'n': ids.n})"
vmEnterScopeCode string = "vm_enter_scope()"
vmExitScopeCode string = "vm_exit_scope()"
setAddCode string = "assert ids.elm_size > 0\nassert ids.set_ptr <= ids.set_end_ptr\nelm_list = memory.get_range(ids.elm_ptr, ids.elm_size)\nfor i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):\n if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:\n ids.index = i // ids.elm_size\n ids.is_elm_in_set = 1\n break\nelse:\n ids.is_elm_in_set = 0"
)
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createMemEnterScopeHinter(resolver, true)
case vmExitScopeCode:
return createVMExitScopeHinter()
case setAddCode:
return createSetAddHinter(resolver)
case testAssignCode:
return createTestAssignHinter(resolver)
default:
Expand Down
122 changes: 122 additions & 0 deletions pkg/hintrunner/zero/zerohint_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zero

import (
"fmt"
"reflect"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
Expand Down Expand Up @@ -150,3 +151,124 @@ func createMemEnterScopeHinter(resolver hintReferenceResolver, memset bool) (hin
}
return newMemEnterScopeHint(value, memset), nil
}

func newSetAddHint(elmSize, elmPtr, setPtr, setEndPtr, index, isElmInSet hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "SetAdd",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert ids.elm_size > 0
//> assert ids.set_ptr <= ids.set_end_ptr
//> elm_list = memory.get_range(ids.elm_ptr, ids.elm_size)
//> for i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):
//> if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:
//> ids.index = i // ids.elm_size
//> ids.is_elm_in_set = 1
//> break
//> else:
//> ids.is_elm_in_set = 0

elmSize, err := hinter.ResolveAsUint64(vm, elmSize)
if err != nil {
return err
}
elmPtr, err := hinter.ResolveAsAddress(vm, elmPtr)
if err != nil {
return err
}
setPtr, err := hinter.ResolveAsAddress(vm, setPtr)
if err != nil {
return err
}
setEndPtr, err := hinter.ResolveAsAddress(vm, setEndPtr)
if err != nil {
return err
}
indexAddr, err := index.GetAddress(vm)
if err != nil {
return err
}
isElmInSetAddr, err := isElmInSet.GetAddress(vm)
if err != nil {
return err
}

//> assert ids.elm_size > 0
if elmSize == 0 {
return fmt.Errorf("assert ids.elm_size > 0 failed")
}

//> assert ids.set_ptr <= ids.set_end_ptr
if setPtr.Offset > setEndPtr.Offset {
return fmt.Errorf("assert ids.set_ptr <= ids.set_end_ptr failed")
}

//> elm_list = memory.get_range(ids.elm_ptr, ids.elm_size)
elmList, err := vm.Memory.GetConsecutiveMemoryValues(*elmPtr, int16(elmSize))
if err != nil {
return err
}

//> for i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):
//> if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:
//> ids.index = i // ids.elm_size
//> ids.is_elm_in_set = 1
//> break
//> else:
//> ids.is_elm_in_set = 0
isElmInSetFelt := utils.FeltZero
totalSetLength := setEndPtr.Offset - setPtr.Offset
for i := uint64(0); i < totalSetLength; i += elmSize {
memoryElmList, err := vm.Memory.GetConsecutiveMemoryValues(*setPtr, int16(elmSize))
if err != nil {
return err
}
*setPtr, err = setPtr.AddOffset(int16(elmSize))
if err != nil {
return err
}
if reflect.DeepEqual(memoryElmList, elmList) {
indexFelt := fp.NewElement(i / elmSize)
indexMv := memory.MemoryValueFromFieldElement(&indexFelt)
err := vm.Memory.WriteToAddress(&indexAddr, &indexMv)
if err != nil {
return err
}
isElmInSetFelt = utils.FeltOne
break
}
}

mv := memory.MemoryValueFromFieldElement(&isElmInSetFelt)
return vm.Memory.WriteToAddress(&isElmInSetAddr, &mv)
},
}
}

func createSetAddHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
elmSize, err := resolver.GetResOperander("elm_size")
if err != nil {
return nil, err
}
elmPtr, err := resolver.GetResOperander("elm_ptr")
if err != nil {
return nil, err
}
setPtr, err := resolver.GetResOperander("set_ptr")
if err != nil {
return nil, err
}
setEndPtr, err := resolver.GetResOperander("set_end_ptr")
if err != nil {
return nil, err
}
index, err := resolver.GetResOperander("index")
if err != nil {
return nil, err
}
isElmInSet, err := resolver.GetResOperander("is_elm_in_set")
if err != nil {
return nil, err
}

return newSetAddHint(elmSize, elmPtr, setPtr, setEndPtr, index, isElmInSet), nil
}
168 changes: 167 additions & 1 deletion pkg/hintrunner/zero/zerohint_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func TestZeroHintMemcpy(t *testing.T) {
func TestZeroHintOthers(t *testing.T) {
runHinterTests(t, map[string][]hintTestCase{
"MemcpyContinueCopying": {
{
Expand Down Expand Up @@ -57,5 +58,170 @@ func TestZeroHintMemcpy(t *testing.T) {
check: varValueInScopeEquals("n", *feltUint64(1)),
},
},
"SetAdd": {
{
operanders: []*hintOperander{
{Name: "elm_size", Kind: apRelative, Value: feltUint64(0)},
{Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)},
{Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)},
{Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)},
{Name: "index", Kind: uninitialized},
{Name: "is_elm_in_set", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSetAddHint(
ctx.operanders["elm_size"],
ctx.operanders["elm_ptr"],
ctx.operanders["set_ptr"],
ctx.operanders["set_end_ptr"],
ctx.operanders["index"],
ctx.operanders["is_elm_in_set"],
)
},
errCheck: errorTextContains("assert ids.elm_size > 0 failed"),
},
{
operanders: []*hintOperander{
{Name: "elm_size", Kind: apRelative, Value: feltUint64(1)},
{Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)},
{Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 1)},
{Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 0)},
{Name: "index", Kind: uninitialized},
{Name: "is_elm_in_set", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSetAddHint(
ctx.operanders["elm_size"],
ctx.operanders["elm_ptr"],
ctx.operanders["set_ptr"],
ctx.operanders["set_end_ptr"],
ctx.operanders["index"],
ctx.operanders["is_elm_in_set"],
)
},
errCheck: errorTextContains("assert ids.set_ptr <= ids.set_end_ptr failed"),
},
{
operanders: []*hintOperander{
{Name: "elm.1", Kind: apRelative, Value: feltUint64(1)},
{Name: "elm.2", Kind: apRelative, Value: feltUint64(2)},
{Name: "elm.3", Kind: apRelative, Value: feltUint64(3)},
{Name: "elm.4", Kind: apRelative, Value: feltUint64(4)},
{Name: "set.1", Kind: apRelative, Value: feltUint64(5)},
{Name: "set.2", Kind: apRelative, Value: feltUint64(6)},
{Name: "set.3", Kind: apRelative, Value: feltUint64(7)},
{Name: "set.4", Kind: apRelative, Value: feltUint64(8)},
{Name: "set.5", Kind: apRelative, Value: feltUint64(9)},
{Name: "set.6", Kind: apRelative, Value: feltUint64(10)},
{Name: "set.7", Kind: apRelative, Value: feltUint64(11)},
{Name: "set.8", Kind: apRelative, Value: feltUint64(12)},
{Name: "set.9", Kind: apRelative, Value: feltUint64(1)},
{Name: "set.10", Kind: apRelative, Value: feltUint64(2)},
{Name: "set.11", Kind: apRelative, Value: feltUint64(3)},
{Name: "set.12", Kind: apRelative, Value: feltUint64(4)},
{Name: "set.13", Kind: apRelative, Value: feltUint64(13)},
{Name: "set.14", Kind: apRelative, Value: feltUint64(14)},
{Name: "set.15", Kind: apRelative, Value: feltUint64(15)},
{Name: "set.16", Kind: apRelative, Value: feltUint64(16)},
{Name: "elm_size", Kind: apRelative, Value: feltUint64(4)},
{Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)},
{Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 8)},
{Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 24)},
{Name: "index", Kind: uninitialized},
{Name: "is_elm_in_set", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSetAddHint(
ctx.operanders["elm_size"],
ctx.operanders["elm_ptr"],
ctx.operanders["set_ptr"],
ctx.operanders["set_end_ptr"],
ctx.operanders["index"],
ctx.operanders["is_elm_in_set"],
)
},
check: allVarValueEquals(map[string]*fp.Element{
"index": feltUint64(2),
"is_elm_in_set": feltUint64(1),
}),
},
{
operanders: []*hintOperander{
{Name: "elm.1", Kind: apRelative, Value: feltUint64(1)},
{Name: "elm.2", Kind: apRelative, Value: feltUint64(2)},
{Name: "elm.3", Kind: apRelative, Value: feltUint64(3)},
{Name: "elm.4", Kind: apRelative, Value: feltUint64(4)},
{Name: "set.1", Kind: apRelative, Value: feltUint64(5)},
{Name: "set.2", Kind: apRelative, Value: feltUint64(6)},
{Name: "set.3", Kind: apRelative, Value: feltUint64(7)},
{Name: "set.4", Kind: apRelative, Value: feltUint64(8)},
{Name: "set.5", Kind: apRelative, Value: feltUint64(9)},
{Name: "set.6", Kind: apRelative, Value: feltUint64(10)},
{Name: "set.7", Kind: apRelative, Value: feltUint64(11)},
{Name: "set.8", Kind: apRelative, Value: feltUint64(12)},
{Name: "set.9", Kind: apRelative, Value: feltUint64(13)},
{Name: "set.10", Kind: apRelative, Value: feltUint64(14)},
{Name: "set.11", Kind: apRelative, Value: feltUint64(15)},
{Name: "set.12", Kind: apRelative, Value: feltUint64(16)},
{Name: "set.13", Kind: apRelative, Value: feltUint64(17)},
{Name: "set.14", Kind: apRelative, Value: feltUint64(18)},
{Name: "set.15", Kind: apRelative, Value: feltUint64(19)},
{Name: "set.16", Kind: apRelative, Value: feltUint64(20)},
{Name: "elm_size", Kind: apRelative, Value: feltUint64(4)},
{Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)},
{Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 8)},
{Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 24)},
{Name: "index", Kind: uninitialized},
{Name: "is_elm_in_set", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSetAddHint(
ctx.operanders["elm_size"],
ctx.operanders["elm_ptr"],
ctx.operanders["set_ptr"],
ctx.operanders["set_end_ptr"],
ctx.operanders["index"],
ctx.operanders["is_elm_in_set"],
)
},
check: allVarValueEquals(map[string]*fp.Element{
"is_elm_in_set": feltUint64(0),
}),
},
{
operanders: []*hintOperander{
{Name: "elm.1", Kind: apRelative, Value: feltUint64(1)},
{Name: "elm.2", Kind: apRelative, Value: feltUint64(2)},
{Name: "elm.3", Kind: apRelative, Value: feltUint64(3)},
{Name: "elm.4", Kind: apRelative, Value: feltUint64(4)},
{Name: "elm.5", Kind: apRelative, Value: feltUint64(5)},
{Name: "set.1", Kind: apRelative, Value: feltUint64(1)},
{Name: "set.2", Kind: apRelative, Value: feltUint64(2)},
{Name: "set.3", Kind: apRelative, Value: feltUint64(3)},
{Name: "set.4", Kind: apRelative, Value: feltUint64(4)},
{Name: "set.5", Kind: apRelative, Value: feltUint64(5)},
{Name: "elm_size", Kind: apRelative, Value: feltUint64(5)},
{Name: "elm_ptr", Kind: apRelative, Value: addrWithSegment(1, 4)},
{Name: "set_ptr", Kind: apRelative, Value: addrWithSegment(1, 9)},
{Name: "set_end_ptr", Kind: apRelative, Value: addrWithSegment(1, 14)},
{Name: "index", Kind: uninitialized},
{Name: "is_elm_in_set", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSetAddHint(
ctx.operanders["elm_size"],
ctx.operanders["elm_ptr"],
ctx.operanders["set_ptr"],
ctx.operanders["set_end_ptr"],
ctx.operanders["index"],
ctx.operanders["is_elm_in_set"],
)
},
check: allVarValueEquals(map[string]*fp.Element{
"index": feltUint64(0),
"is_elm_in_set": feltUint64(1),
}),
},
},
})
}

0 comments on commit 9d98260

Please sign in to comment.