diff --git a/src/terralib.lua b/src/terralib.lua index 351238da..22abe503 100644 --- a/src/terralib.lua +++ b/src/terralib.lua @@ -3158,7 +3158,7 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) return newobject(s,T.fornum,variable,initial,limit,step,body) elseif s:is "forlist" then local iterator = checkexp(s.iterator) - + local typ = iterator.type if typ:ispointertostruct() then typ,iterator = typ.type, insertdereference(iterator) @@ -3167,8 +3167,16 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) diag:reporterror(iterator,"expected a struct with a __for metamethod but found ",typ) return s end + local itersym = terra.newsymbol(typ, "__for_iter") + local itervar = newobject(s, T.allocvar, "__for_iter", itersym) + local iterref = newobject(s, T.var, "__for_iter", itersym) + iterref.type = typ + local iterAssign = asterraexpression( + s, + createassignment(s, List {itervar}, List {asterraexpression(s, iterator)}), + "statement") local generator = typ.metamethods.__for - + local function bodycallback(...) local exps = List() for i = 1,select("#",...) do @@ -3183,9 +3191,12 @@ function typecheck(topexp,luaenv,simultaneousdefinitions) local stats = createstatementlist(s, List { assign, body }) return terra.newquote(stats) end - - local value = invokeuserfunction(s, "invoking __for", false ,generator,terra.newquote(iterator), bodycallback) - return asterraexpression(s,value,"statement") + + local value = asterraexpression( + s, + invokeuserfunction(s, "invoking __for", false, generator, iterref, bodycallback), + "statement") + return asterraexpression(s,createstatementlist(s, List { iterAssign, value }),"statement") elseif s:is "ifstat" then local br = s.branches:map(checkcondbranch) local els = (s.orelse and checkblock(s.orelse)) diff --git a/tests/forlist3.t b/tests/forlist3.t new file mode 100644 index 00000000..ca5a48e5 --- /dev/null +++ b/tests/forlist3.t @@ -0,0 +1,35 @@ + +local callcount = 0 + +struct iter {n: int} + +iter.metamethods.__for = function(self, body) + return quote + [ body(`self.n) ] + [ body(`self.n) ] + end +end + +local callinfo = {n=0} +terra this_should_be_called_once(n: int) + [terralib.cast({} -> {}, function() callinfo.n = callinfo.n + 1 end)]() + return iter{n} +end + +local checkcalls = {n = 0, expect = {5, 5}} +local spy = terralib.cast({int} -> {}, function(val) + checkcalls.n = checkcalls.n + 1 + assert(checkcalls.expect[checkcalls.n] == val, "spy called with incorrect value") + end +) + +terra test() + for x in this_should_be_called_once(5) do + spy(x) + end +end + +test() + +assert(checkcalls.n == #checkcalls.expect, "spy called the wrong number of times") +assert(callinfo.n == 1, "body expansion called the wrong number of times.")