Skip to content

Commit

Permalink
Merge pull request #2892 from darkdrag00nv2/range_type_for_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
turbolent authored Nov 8, 2023
2 parents e523180 + b512835 commit 7ab492f
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 12 deletions.
4 changes: 2 additions & 2 deletions runtime/interpreter/interpreter_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S
panic(errors.NewUnreachableError())
}

iterator := iterable.Iterator(interpreter)
iterator := iterable.Iterator(interpreter, locationRange)

var indexVariable *Variable
if statement.Index != nil {
Expand All @@ -342,7 +342,7 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S
}

for {
value := iterator.Next(interpreter)
value := iterator.Next(interpreter, locationRange)
if value == nil {
return nil
}
Expand Down
81 changes: 75 additions & 6 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ type LinkValue interface {
// IterableValue is a value which can be iterated over, e.g. with a for-loop
type IterableValue interface {
Value
Iterator(interpreter *Interpreter) ValueIterator
Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator
}

// ValueIterator is an iterator which returns values.
// When Next returns nil, it signals the end of the iterator.
type ValueIterator interface {
Next(interpreter *Interpreter) Value
Next(interpreter *Interpreter, locationRange LocationRange) Value
}

func safeAdd(a, b int, locationRange LocationRange) int {
Expand Down Expand Up @@ -1516,7 +1516,7 @@ func (v *StringValue) ConformsToStaticType(
return true
}

func (v *StringValue) Iterator(_ *Interpreter) ValueIterator {
func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
return StringValueIterator{
graphemes: uniseg.NewGraphemes(v.Str),
}
Expand All @@ -1528,7 +1528,7 @@ type StringValueIterator struct {

var _ ValueIterator = StringValueIterator{}

func (i StringValueIterator) Next(_ *Interpreter) Value {
func (i StringValueIterator) Next(_ *Interpreter, _ LocationRange) Value {
if !i.graphemes.Next() {
return nil
}
Expand All @@ -1550,7 +1550,7 @@ type ArrayValueIterator struct {
atreeIterator *atree.ArrayIterator
}

func (v *ArrayValue) Iterator(_ *Interpreter) ValueIterator {
func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
arrayIterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
Expand All @@ -1562,7 +1562,7 @@ func (v *ArrayValue) Iterator(_ *Interpreter) ValueIterator {

var _ ValueIterator = ArrayValueIterator{}

func (i ArrayValueIterator) Next(interpreter *Interpreter) Value {
func (i ArrayValueIterator) Next(interpreter *Interpreter, _ LocationRange) Value {
atreeValue, err := i.atreeIterator.Next()
if err != nil {
panic(errors.NewExternalError(err))
Expand Down Expand Up @@ -16267,6 +16267,7 @@ type CompositeField struct {
const attachmentNamePrefix = "$"

var _ TypeIndexableValue = &CompositeValue{}
var _ IterableValue = &CompositeValue{}

func NewCompositeField(memoryGauge common.MemoryGauge, name string, value Value) CompositeField {
common.UseMemory(memoryGauge, common.CompositeFieldMemoryUsage)
Expand Down Expand Up @@ -17782,6 +17783,74 @@ func (v *CompositeValue) RemoveTypeKey(
return v.RemoveMember(interpreter, locationRange, attachmentMemberName(attachmentType))
}

func (v *CompositeValue) Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator {
staticType := v.StaticType(interpreter)

switch typ := staticType.(type) {
case InclusiveRangeStaticType:
return NewInclusiveRangeIterator(interpreter, locationRange, v, typ)

default:
// Must be caught in the checker.
panic(errors.NewUnreachableError())
}
}

type InclusiveRangeIterator struct {
rangeValue *CompositeValue
next IntegerValue

// Cached values
stepNegative bool
step IntegerValue
end IntegerValue
}

var _ ValueIterator = &InclusiveRangeIterator{}

func NewInclusiveRangeIterator(
interpreter *Interpreter,
locationRange LocationRange,
v *CompositeValue,
typ InclusiveRangeStaticType,
) *InclusiveRangeIterator {
startValue := getFieldAsIntegerValue(interpreter, v, locationRange, sema.InclusiveRangeTypeStartFieldName)

zeroValue := GetSmallIntegerValue(0, typ.ElementType)
endValue := getFieldAsIntegerValue(interpreter, v, locationRange, sema.InclusiveRangeTypeEndFieldName)

stepValue := getFieldAsIntegerValue(interpreter, v, locationRange, sema.InclusiveRangeTypeStepFieldName)
stepNegative := stepValue.Less(interpreter, zeroValue, locationRange)

return &InclusiveRangeIterator{
rangeValue: v,
next: startValue,
stepNegative: bool(stepNegative),
step: stepValue,
end: endValue,
}
}

func (i *InclusiveRangeIterator) Next(interpreter *Interpreter, locationRange LocationRange) Value {
valueToReturn := i.next

// Ensure that valueToReturn is within the bounds.
if i.stepNegative && bool(valueToReturn.Less(interpreter, i.end, locationRange)) {
return nil
} else if !i.stepNegative && bool(valueToReturn.Greater(interpreter, i.end, locationRange)) {
return nil
}

// Update the next value.
nextValueToReturn, ok := valueToReturn.Plus(interpreter, i.step, locationRange).(IntegerValue)
if !ok {
panic(errors.NewUnreachableError())
}

i.next = nextValueToReturn
return valueToReturn
}

// DictionaryValue

type DictionaryValue struct {
Expand Down
8 changes: 4 additions & 4 deletions runtime/interpreter/value_range.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ func rangeContains(
locationRange LocationRange,
needleValue IntegerValue,
) BoolValue {
start := getFieldAsIntegerValue(rangeValue, interpreter, locationRange, sema.InclusiveRangeTypeStartFieldName)
end := getFieldAsIntegerValue(rangeValue, interpreter, locationRange, sema.InclusiveRangeTypeEndFieldName)
step := getFieldAsIntegerValue(rangeValue, interpreter, locationRange, sema.InclusiveRangeTypeStepFieldName)
start := getFieldAsIntegerValue(interpreter, rangeValue, locationRange, sema.InclusiveRangeTypeStartFieldName)
end := getFieldAsIntegerValue(interpreter, rangeValue, locationRange, sema.InclusiveRangeTypeEndFieldName)
step := getFieldAsIntegerValue(interpreter, rangeValue, locationRange, sema.InclusiveRangeTypeStepFieldName)

result := start.Equal(interpreter, locationRange, needleValue) ||
end.Equal(interpreter, locationRange, needleValue)
Expand Down Expand Up @@ -220,8 +220,8 @@ func rangeContains(
}

func getFieldAsIntegerValue(
rangeValue *CompositeValue,
interpreter *Interpreter,
rangeValue *CompositeValue,
locationRange LocationRange,
name string,
) IntegerValue {
Expand Down
2 changes: 2 additions & 0 deletions runtime/sema/check_for.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func (checker *Checker) VisitForStatement(statement *ast.ForStatement) (_ struct
elementType = arrayType.ElementType(false)
} else if valueType == StringType {
elementType = CharacterType
} else if inclusiveRangeType, ok := valueType.(*InclusiveRangeType); ok {
elementType = inclusiveRangeType.MemberType
} else {
checker.report(
&TypeMismatchWithDescriptionError{
Expand Down
35 changes: 35 additions & 0 deletions runtime/tests/checker/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
package checker

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"github.com/onflow/cadence/runtime/sema"
"github.com/onflow/cadence/runtime/stdlib"
)

func TestCheckForVariableSized(t *testing.T) {
Expand Down Expand Up @@ -76,6 +78,39 @@ func TestCheckForString(t *testing.T) {
assert.NoError(t, err)
}

func TestCheckForInclusiveRange(t *testing.T) {

t.Parallel()

baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

for _, typ := range sema.AllIntegerTypes {
code := fmt.Sprintf(`
fun test() {
let start : %[1]s = 1
let end : %[1]s = 2
let step : %[1]s = 1
let range: InclusiveRange<%[1]s> = InclusiveRange(start, end, step: step)
for value in range {
var typedValue: %[1]s = value
}
}
`, typ.String())

_, err := ParseAndCheckWithOptions(t, code,
ParseAndCheckOptions{
Config: &sema.Config{
BaseValueActivation: baseValueActivation,
},
},
)

assert.NoError(t, err)
}
}

func TestCheckForEmpty(t *testing.T) {

t.Parallel()
Expand Down
138 changes: 138 additions & 0 deletions runtime/tests/interpreter/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
package interpreter_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/onflow/cadence/runtime/activations"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/sema"
"github.com/onflow/cadence/runtime/stdlib"
. "github.com/onflow/cadence/runtime/tests/utils"

"github.com/onflow/cadence/runtime/interpreter"
Expand Down Expand Up @@ -255,3 +260,136 @@ func TestInterpretForString(t *testing.T) {
value,
)
}

type inclusiveRangeForInLoopTest struct {
start, end, step int8
loopElements []int
}

func TestInclusiveRangeForInLoop(t *testing.T) {
t.Parallel()

baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

baseActivation := activations.NewActivation(nil, interpreter.BaseActivation)
interpreter.Declare(baseActivation, stdlib.InclusiveRangeConstructorFunction)

unsignedTestCases := []inclusiveRangeForInLoopTest{
{
start: 0,
end: 10,
step: 1,
loopElements: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
{
start: 0,
end: 10,
step: 2,
loopElements: []int{0, 2, 4, 6, 8, 10},
},
}

signedTestCases := []inclusiveRangeForInLoopTest{
{
start: 10,
end: -10,
step: -2,
loopElements: []int{10, 8, 6, 4, 2, 0, -2, -4, -6, -8, -10},
},
}

runTestCase := func(t *testing.T, typ sema.Type, testCase inclusiveRangeForInLoopTest) {
t.Run(typ.String(), func(t *testing.T) {
t.Parallel()

code := fmt.Sprintf(
`
fun test(): [%[1]s] {
let start : %[1]s = %[2]d
let end : %[1]s = %[3]d
let step : %[1]s = %[4]d
let range: InclusiveRange<%[1]s> = InclusiveRange(start, end, step: step)
var elements : [%[1]s] = []
for element in range {
elements.append(element)
}
return elements
}
`,
typ.String(),
testCase.start,
testCase.end,
testCase.step,
)

inter, err := parseCheckAndInterpretWithOptions(t, code,
ParseCheckAndInterpretOptions{
CheckerConfig: &sema.Config{
BaseValueActivation: baseValueActivation,
},
Config: &interpreter.Config{
BaseActivation: baseActivation,
},
},
)

require.NoError(t, err)
loopElements, err := inter.Invoke("test")
require.NoError(t, err)

integerStaticType := interpreter.ConvertSemaToStaticType(
nil,
typ,
)

count := 0
iterator := (loopElements).(*interpreter.ArrayValue).Iterator(inter, interpreter.EmptyLocationRange)
for {
elem := iterator.Next(inter, interpreter.EmptyLocationRange)
if elem == nil {
break
}

AssertValuesEqual(
t,
inter,
interpreter.GetSmallIntegerValue(
int8(testCase.loopElements[count]),
integerStaticType,
),
elem,
)

count += 1
}

assert.Equal(t, len(testCase.loopElements), count)
})
}

for _, typ := range sema.AllIntegerTypes {
// Only test leaf types
switch typ {
case sema.IntegerType, sema.SignedIntegerType:
continue
}

for _, testCase := range unsignedTestCases {
runTestCase(t, typ, testCase)
}
}

for _, typ := range sema.AllSignedIntegerTypes {
// Only test leaf types
switch typ {
case sema.SignedIntegerType:
continue
}

for _, testCase := range signedTestCases {
runTestCase(t, typ, testCase)
}
}
}

0 comments on commit 7ab492f

Please sign in to comment.