diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index f813b80e6a..8f92804ff8 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -1465,7 +1465,11 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location panic(errors.NewUnreachableError()) } - return v.Split(invocation.Interpreter, invocation.LocationRange, separator.Str) + return v.Split( + invocation.Interpreter, + invocation.LocationRange, + separator, + ) }, ) @@ -1545,16 +1549,17 @@ func (v *StringValue) ToLower(interpreter *Interpreter) *StringValue { ) } -func (v *StringValue) Split(inter *Interpreter, _ LocationRange, separator string) Value { +func (v *StringValue) Split(inter *Interpreter, locationRange LocationRange, separator *StringValue) *ArrayValue { - // Meter computation as if the string was iterated. - // i.e: linear search to find the split points. This is an estimate. - inter.ReportComputation(common.ComputationKindLoop, uint(len(v.Str))) + if len(separator.Str) == 0 { + return v.Explode(inter, locationRange) + } - split := strings.Split(v.Str, separator) + count := v.count(inter, locationRange, separator) + 1 - var index int - count := len(split) + partIndex := 0 + + remaining := v return NewArrayValueWithIterator( inter, @@ -1562,12 +1567,66 @@ func (v *StringValue) Split(inter *Interpreter, _ LocationRange, separator strin common.ZeroAddress, uint64(count), func() Value { - if index >= count { + + inter.ReportComputation(common.ComputationKindLoop, 1) + + if partIndex >= count { return nil } - str := split[index] - index++ + // Set the remainder as the last part + if partIndex == count-1 { + partIndex++ + return remaining + } + + separatorCharacterIndex := remaining.indexOf(inter, separator) + if separatorCharacterIndex < 0 { + return nil + } + + partIndex++ + + part := remaining.slice( + 0, + separatorCharacterIndex, + locationRange, + ) + + remaining = remaining.slice( + separatorCharacterIndex+separator.Length(), + remaining.Length(), + locationRange, + ) + + return part + }, + ) +} + +// Explode returns a Cadence array of type [String], where each element is a single character of the string +func (v *StringValue) Explode(inter *Interpreter, locationRange LocationRange) *ArrayValue { + + iterator := v.Iterator(inter, locationRange) + + return NewArrayValueWithIterator( + inter, + VarSizedArrayOfStringType, + common.ZeroAddress, + uint64(v.Length()), + func() Value { + value := iterator.Next(inter, locationRange) + if value == nil { + return nil + } + + character, ok := value.(CharacterValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + str := character.Str + return NewStringValue( inter, common.NewStringMemoryUsage(len(str)), diff --git a/runtime/tests/interpreter/metering_test.go b/runtime/tests/interpreter/metering_test.go index 9ca2d40026..776dcecc94 100644 --- a/runtime/tests/interpreter/metering_test.go +++ b/runtime/tests/interpreter/metering_test.go @@ -657,6 +657,6 @@ func TestInterpretStdlibComputationMetering(t *testing.T) { _, err = inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, uint(10), computationMeteredValues[common.ComputationKindLoop]) + assert.Equal(t, uint(58), computationMeteredValues[common.ComputationKindLoop]) }) } diff --git a/runtime/tests/interpreter/string_test.go b/runtime/tests/interpreter/string_test.go index 6a276d245c..d6250b8631 100644 --- a/runtime/tests/interpreter/string_test.go +++ b/runtime/tests/interpreter/string_test.go @@ -504,97 +504,114 @@ func TestInterpretStringSplit(t *testing.T) { t.Parallel() - inter := parseCheckAndInterpret(t, ` - fun split(): [String] { - return "👪////❤️".split(separator: "////") - } - fun splitBySpace(): [String] { - return "👪 ❤️ Abc6 ;123".split(separator: " ") - } - fun splitWithUnicodeEquivalence(): [String] { - return "Caf\u{65}\u{301}ABc".split(separator: "\u{e9}") - } - fun testEmptyString(): [String] { - return "".split(separator: "//") - } - fun testNoMatch(): [String] { - return "pqrS;asdf".split(separator: ";;") - } - `) + type test struct { + str string + sep string + result []string + } - testCase := func(t *testing.T, funcName string, expected *interpreter.ArrayValue) { - t.Run(funcName, func(t *testing.T) { - result, err := inter.Invoke(funcName) - require.NoError(t, err) + var abcd = "abcd" + var faces = "☺☻☹" + var commas = "1,2,3,4" + var dots = "1....2....3....4" - RequireValuesEqual( - t, - inter, - expected, - result, + tests := []test{ + {"", "", []string{}}, + {abcd, "", []string{"a", "b", "c", "d"}}, + {faces, "", []string{"☺", "☻", "☹"}}, + {"☺�☹", "", []string{"☺", "�", "☹"}}, + {abcd, "a", []string{"", "bcd"}}, + {abcd, "z", []string{"abcd"}}, + {commas, ",", []string{"1", "2", "3", "4"}}, + {dots, "...", []string{"1", ".2", ".3", ".4"}}, + {faces, "☹", []string{"☺☻", ""}}, + {faces, "~", []string{faces}}, + { + "\\u{1F46A}////\\u{2764}\\u{FE0F}", + "////", + []string{"\U0001F46A", "\u2764\uFE0F"}, + }, + { + "\\u{1F46A} \\u{2764}\\u{FE0F} Abc6 ;123", + " ", + []string{"\U0001F46A", "\u2764\uFE0F", "Abc6", ";123"}, + }, + { + "Caf\\u{65}\\u{301}ABc", + "\\u{e9}", + []string{"Caf", "ABc"}, + }, + { + "", + "//", + []string{""}, + }, + { + "pqrS;asdf", + ";;", + []string{"pqrS;asdf"}, + }, + { + // U+1F476 U+1F3FB is 👶🏻 + " \\u{1F476}\\u{1F3FB} ascii \\u{D}\\u{A}", + " ", + []string{"", "\U0001F476\U0001F3FB", "ascii", "\u000D\u000A"}, + }, + // 🇪🇸🇸🇪🇪🇪 is "ES", "SE", "EE" + { + "\\u{1F1EA}\\u{1F1F8}\\u{1F1F8}\\u{1F1EA}\\u{1F1EA}\\u{1F1EA}", + "\\u{1F1F8}\\u{1F1EA}", + []string{"\U0001F1EA\U0001F1F8", "\U0001F1EA\U0001F1EA"}, + }, + } + + runTest := func(test test) { + + name := fmt.Sprintf("%s, %s", test.str, test.sep) + + t.Run(name, func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, + fmt.Sprintf( + ` + fun test(): [String] { + let s = "%s" + return s.split(separator: "%s") + } + `, + test.str, + test.sep, + ), ) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.IsType(t, &interpreter.ArrayValue{}, value) + actual := value.(*interpreter.ArrayValue) + + require.Equal(t, len(test.result), actual.Count()) + + for partIndex, expected := range test.result { + actualPart := actual.Get( + inter, + interpreter.EmptyLocationRange, + partIndex, + ) + + require.IsType(t, &interpreter.StringValue{}, actualPart) + actualPartString := actualPart.(*interpreter.StringValue) + + require.Equal(t, expected, actualPartString.Str) + } }) } - varSizedStringType := &interpreter.VariableSizedStaticType{ - Type: interpreter.PrimitiveStaticTypeString, + for _, test := range tests { + runTest(test) } - - testCase(t, - "split", - interpreter.NewArrayValue( - inter, - interpreter.EmptyLocationRange, - varSizedStringType, - common.ZeroAddress, - interpreter.NewUnmeteredStringValue("👪"), - interpreter.NewUnmeteredStringValue("❤️"), - ), - ) - testCase(t, - "splitBySpace", - interpreter.NewArrayValue( - inter, - interpreter.EmptyLocationRange, - varSizedStringType, - common.ZeroAddress, - interpreter.NewUnmeteredStringValue("👪"), - interpreter.NewUnmeteredStringValue("❤️"), - interpreter.NewUnmeteredStringValue("Abc6"), - interpreter.NewUnmeteredStringValue(";123"), - ), - ) - testCase(t, - "splitWithUnicodeEquivalence", - interpreter.NewArrayValue( - inter, - interpreter.EmptyLocationRange, - varSizedStringType, - common.ZeroAddress, - interpreter.NewUnmeteredStringValue("Caf"), - interpreter.NewUnmeteredStringValue("ABc"), - ), - ) - testCase(t, - "testEmptyString", - interpreter.NewArrayValue( - inter, - interpreter.EmptyLocationRange, - varSizedStringType, - common.ZeroAddress, - interpreter.NewUnmeteredStringValue(""), - ), - ) - testCase(t, - "testNoMatch", - interpreter.NewArrayValue( - inter, - interpreter.EmptyLocationRange, - varSizedStringType, - common.ZeroAddress, - interpreter.NewUnmeteredStringValue("pqrS;asdf"), - ), - ) } func TestInterpretStringReplaceAll(t *testing.T) {