Skip to content

Commit

Permalink
fix(map): create iterators in map to save state
Browse files Browse the repository at this point in the history
  • Loading branch information
tdakkota committed May 16, 2021
1 parent 8b89ed7 commit d264ec9
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 13 deletions.
1 change: 1 addition & 0 deletions assert/assert.star
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ assert = module(
lt = _lt,
contains = _contains,
fails = _fails,
gc = _gc,
)
6 changes: 3 additions & 3 deletions itertools.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ type floatOrInt struct {
}

// Unpacker for floatOrInt.
func (p *floatOrInt) Unpack(v starlark.Value) error {
func (f *floatOrInt) Unpack(v starlark.Value) error {
switch v := v.(type) {
case starlark.Int:
p.value = v
f.value = v
return nil
case starlark.Float:
p.value = v
f.value = v
return nil
}
return fmt.Errorf("got %s, want float or int", v.Type())
Expand Down
29 changes: 20 additions & 9 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sun

import (
"fmt"
"runtime"

"go.starlark.net/starlark"
)
Expand Down Expand Up @@ -43,6 +44,7 @@ type mapObject struct {
thread *starlark.Thread
function starlark.Callable
iterables []starlark.Iterable
iterators []starlark.Iterator
}

func (f mapObject) String() string {
Expand All @@ -69,17 +71,12 @@ func (f mapObject) Hash() (uint32, error) {
}

func (f mapObject) Iterate() starlark.Iterator {
iterators := make([]starlark.Iterator, len(f.iterables))
for i := range iterators {
iterators[i] = f.iterables[i].Iterate()
}

// TODO(tdakkota): specialize iterator if there is only one iterable.
return &mapIter{
thread: f.thread,
function: f.function,
iterators: iterators,
buf: make([]starlark.Value, 0, len(iterators)),
iterators: f.iterators,
buf: make([]starlark.Value, 0, len(f.iterators)),
}
}

Expand All @@ -103,9 +100,23 @@ func map_(
return nil, err
}

return &mapObject{
iterators := make([]starlark.Iterator, len(iterables))
for i := range iterators {
iterators[i] = iterables[i].Iterate()
}

obj := &mapObject{
thread: thread,
function: function,
iterables: iterables,
}, nil
iterators: iterators,
}
// TODO(tdakkota): find better way to release iterators
runtime.SetFinalizer(obj, func(m *mapObject) {
for _, iterator := range m.iterators {
iterator.Done()
}
})

return obj, nil
}
7 changes: 7 additions & 0 deletions starlark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
_ "embed"
"fmt"
"regexp"
"runtime"
"strings"
"sync"

Expand Down Expand Up @@ -70,13 +71,19 @@ func LoadAssertModule() (starlark.StringDict, error) {
"matches": starlark.NewBuiltin("matches", matches),
"module": starlark.NewBuiltin("module", starlarkstruct.MakeModule),
"_freeze": starlark.NewBuiltin("freeze", freeze),
"_gc": starlark.NewBuiltin("gc", gc),
}
thread := new(starlark.Thread)
assert, assertErr = starlark.ExecFile(thread, "assert.star", assertFile, predeclared)
})
return assert, assertErr
}

func gc(*starlark.Thread, *starlark.Builtin, starlark.Tuple, []starlark.Tuple) (starlark.Value, error) {
runtime.GC()
return starlark.None, nil
}

// catch(f) evaluates f() and returns its evaluation error message
// if it failed or None if it succeeded.
func catch(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
Expand Down
2 changes: 1 addition & 1 deletion sun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// load implements the 'load' operation as used in the evaluator tests.
func load(thread *starlark.Thread, module string) (starlark.StringDict, error) {
func load(_ *starlark.Thread, module string) (starlark.StringDict, error) {
if module == "assert.star" {
return LoadAssertModule()
}
Expand Down
11 changes: 11 additions & 0 deletions testdata/map.star
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,15 @@ def test_main():
[(2,2,2), (3,3,3), (4,4,4), (5,5,5),],
)

def f(_a):
return _a

m = map(f, [1, 2, 3])
assert.eq(next(m), 1)
assert.eq(next(m), 2)
assert.eq(next(m), 3)
assert.fails(lambda: next(m), "iteration done")
m = None
assert.gc()

test_main()

0 comments on commit d264ec9

Please sign in to comment.