From d264ec9b693e84d499d52d43f762135832c61c23 Mon Sep 17 00:00:00 2001 From: tdakkota Date: Sun, 16 May 2021 11:41:17 +0300 Subject: [PATCH] fix(map): create iterators in map to save state --- assert/assert.star | 1 + itertools.go | 6 +++--- map.go | 29 ++++++++++++++++++++--------- starlark_test.go | 7 +++++++ sun_test.go | 2 +- testdata/map.star | 11 +++++++++++ 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/assert/assert.star b/assert/assert.star index c6e480f..4cacc98 100644 --- a/assert/assert.star +++ b/assert/assert.star @@ -48,4 +48,5 @@ assert = module( lt = _lt, contains = _contains, fails = _fails, + gc = _gc, ) diff --git a/itertools.go b/itertools.go index 9bdf58b..6bf9b03 100644 --- a/itertools.go +++ b/itertools.go @@ -15,13 +15,13 @@ type floatOrInt struct { } // Unpacker for floatOrInt. -func (p *floatOrInt) Unpack(v starlark.Value) error { +func (f *floatOrInt) Unpack(v starlark.Value) error { switch v := v.(type) { case starlark.Int: - p.value = v + f.value = v return nil case starlark.Float: - p.value = v + f.value = v return nil } return fmt.Errorf("got %s, want float or int", v.Type()) diff --git a/map.go b/map.go index 2b07460..b69bfd2 100644 --- a/map.go +++ b/map.go @@ -2,6 +2,7 @@ package sun import ( "fmt" + "runtime" "go.starlark.net/starlark" ) @@ -43,6 +44,7 @@ type mapObject struct { thread *starlark.Thread function starlark.Callable iterables []starlark.Iterable + iterators []starlark.Iterator } func (f mapObject) String() string { @@ -69,17 +71,12 @@ func (f mapObject) Hash() (uint32, error) { } func (f mapObject) Iterate() starlark.Iterator { - iterators := make([]starlark.Iterator, len(f.iterables)) - for i := range iterators { - iterators[i] = f.iterables[i].Iterate() - } - // TODO(tdakkota): specialize iterator if there is only one iterable. return &mapIter{ thread: f.thread, function: f.function, - iterators: iterators, - buf: make([]starlark.Value, 0, len(iterators)), + iterators: f.iterators, + buf: make([]starlark.Value, 0, len(f.iterators)), } } @@ -103,9 +100,23 @@ func map_( return nil, err } - return &mapObject{ + iterators := make([]starlark.Iterator, len(iterables)) + for i := range iterators { + iterators[i] = iterables[i].Iterate() + } + + obj := &mapObject{ thread: thread, function: function, iterables: iterables, - }, nil + iterators: iterators, + } + // TODO(tdakkota): find better way to release iterators + runtime.SetFinalizer(obj, func(m *mapObject) { + for _, iterator := range m.iterators { + iterator.Done() + } + }) + + return obj, nil } diff --git a/starlark_test.go b/starlark_test.go index cd0d20c..64d51f1 100644 --- a/starlark_test.go +++ b/starlark_test.go @@ -19,6 +19,7 @@ import ( _ "embed" "fmt" "regexp" + "runtime" "strings" "sync" @@ -70,6 +71,7 @@ func LoadAssertModule() (starlark.StringDict, error) { "matches": starlark.NewBuiltin("matches", matches), "module": starlark.NewBuiltin("module", starlarkstruct.MakeModule), "_freeze": starlark.NewBuiltin("freeze", freeze), + "_gc": starlark.NewBuiltin("gc", gc), } thread := new(starlark.Thread) assert, assertErr = starlark.ExecFile(thread, "assert.star", assertFile, predeclared) @@ -77,6 +79,11 @@ func LoadAssertModule() (starlark.StringDict, error) { return assert, assertErr } +func gc(*starlark.Thread, *starlark.Builtin, starlark.Tuple, []starlark.Tuple) (starlark.Value, error) { + runtime.GC() + return starlark.None, nil +} + // catch(f) evaluates f() and returns its evaluation error message // if it failed or None if it succeeded. func catch(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { diff --git a/sun_test.go b/sun_test.go index 63d07a1..08eff30 100644 --- a/sun_test.go +++ b/sun_test.go @@ -9,7 +9,7 @@ import ( ) // load implements the 'load' operation as used in the evaluator tests. -func load(thread *starlark.Thread, module string) (starlark.StringDict, error) { +func load(_ *starlark.Thread, module string) (starlark.StringDict, error) { if module == "assert.star" { return LoadAssertModule() } diff --git a/testdata/map.star b/testdata/map.star index 001016d..7865f64 100644 --- a/testdata/map.star +++ b/testdata/map.star @@ -29,4 +29,15 @@ def test_main(): [(2,2,2), (3,3,3), (4,4,4), (5,5,5),], ) + def f(_a): + return _a + + m = map(f, [1, 2, 3]) + assert.eq(next(m), 1) + assert.eq(next(m), 2) + assert.eq(next(m), 3) + assert.fails(lambda: next(m), "iteration done") + m = None + assert.gc() + test_main() \ No newline at end of file