Skip to content

Commit

Permalink
Write a test where both the receiver and a param are captured
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 14, 2023
1 parent 3ef4d26 commit ad6ce04
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 41 deletions.
4 changes: 2 additions & 2 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ func TestCoroutineYield(t *testing.T) {

{
name: "closure capturing receiver and param",
coro: func() { StructClosure(0, 3) },
yields: []int{-1, 0, 1, 2},
coro: func() { StructClosure(3) },
yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},

{
Expand Down
27 changes: 17 additions & 10 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,21 +579,28 @@ func IdentityGenericStructInt(n int) {
}

type Box struct {
n int
x int
}

func (b *Box) Closure() func() {
func (b *Box) Closure(y int) func(int) {
// Force compilation of this method and the closure within.
// Remove once #84 is fixed.
coroutine.Yield[int, any](-1)
return func() {
coroutine.Yield[int, any](b.n)
b.n++

return func(z int) {
coroutine.Yield[int, any](b.x)
coroutine.Yield[int, any](y)
coroutine.Yield[int, any](z)
b.x++
y++
z++ // mutation is lost
}
}

func StructClosure(n, count int) {
box := Box{n}
fn := box.Closure()
for i := 0; i < count; i++ {
fn()
func StructClosure(n int) {
box := Box{10}
fn := box.Closure(100)
for i := 0; i < n; i++ {
fn(1000)
}
}
80 changes: 51 additions & 29 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3254,24 +3254,27 @@ func IdentityGenericInt(n int) { IdentityGeneric[int](n) }
func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() }

type Box struct {
n int
x int
}

//go:noinline
func (_fn0 *Box) Closure() (_ func()) {
func (_fn0 *Box) Closure(_fn1 int) (_ func(int)) {
_c := coroutine.LoadContext[int, any]()
var _f1 *struct {
IP int
X0 *Box
X1 int
} = coroutine.Push[struct {
IP int
X0 *Box
X1 int
}](&_c.Stack)
if _f1.IP == 0 {
*_f1 = struct {
IP int
X0 *Box
}{X0: _fn0}
X1 int
}{X0: _fn0, X1: _fn1}
}
defer func() {
if !_c.Unwinding() {
Expand All @@ -3280,21 +3283,26 @@ func (_fn0 *Box) Closure() (_ func()) {
}()
switch {
case _f1.IP < 2:

coroutine.Yield[int, any](-1)
_f1.IP = 2
fallthrough
case _f1.IP < 3:
return func() {

return func(_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
} = coroutine.Push[struct {
IP int
X0 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
}{}
X0 int
}{X0: _fn0}
}
defer func() {
if !_c.Unwinding() {
Expand All @@ -3303,45 +3311,58 @@ func (_fn0 *Box) Closure() (_ func()) {
}()
switch {
case _f0.IP < 2:
coroutine.Yield[int, any](_f1.X0.n)
coroutine.Yield[int, any](_f1.X0.x)
_f0.IP = 2
fallthrough
case _f0.IP < 3:
coroutine.Yield[int, any](_f1.X1)
_f0.IP = 3
fallthrough
case _f0.IP < 4:
coroutine.Yield[int, any](_f0.X0)
_f0.IP = 4
fallthrough
case _f0.IP < 5:
_f1.X0.
n++
x++
_f0.IP = 5
fallthrough
case _f0.IP < 6:
_f1.X1++
_f0.IP = 6
fallthrough
case _f0.IP < 7:
_f0.X0++
}
}
}
panic("unreachable")
}

//go:noinline
func StructClosure(_fn0, _fn1 int) {
func StructClosure(_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
X1 int
X2 Box
X3 func()
X4 int
X1 Box
X2 func(int)
X3 int
} = coroutine.Push[struct {
IP int
X0 int
X1 int
X2 Box
X3 func()
X4 int
X1 Box
X2 func(int)
X3 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
X1 int
X2 Box
X3 func()
X4 int
}{X0: _fn0, X1: _fn1}
X1 Box
X2 func(int)
X3 int
}{X0: _fn0}
}
defer func() {
if !_c.Unwinding() {
Expand All @@ -3350,33 +3371,34 @@ func StructClosure(_fn0, _fn1 int) {
}()
switch {
case _f0.IP < 2:
_f0.X2 = Box{_f0.X0}
_f0.X1 = Box{10}
_f0.IP = 2
fallthrough
case _f0.IP < 3:
_f0.X3 = _f0.X2.Closure()
_f0.X2 = _f0.X1.Closure(100)
_f0.IP = 3
fallthrough
case _f0.IP < 5:
switch {
case _f0.IP < 4:
_f0.X4 = 0
_f0.X3 = 0
_f0.IP = 4
fallthrough
case _f0.IP < 5:
for ; _f0.X4 < _f0.X1; _f0.X4, _f0.IP = _f0.X4+1, 4 {
_f0.X3()
for ; _f0.X3 < _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 {
_f0.X2(1000)
}
}
}
}
func init() {
_types.RegisterFunc[func() (_ func())]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure")
_types.RegisterClosure[func(), struct {
_types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure")
_types.RegisterClosure[func(_fn0 int), struct {
F uintptr
X0 *struct {
IP int
X0 *Box
X1 int
}
}]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2")
_types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator")
Expand Down Expand Up @@ -3483,7 +3505,7 @@ func init() {
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop")
_types.RegisterFunc[func(_fn0, _fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.VarArgs")
_types.RegisterFunc[func(_fn0 *int, _fn1, _fn2 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign")
Expand Down

0 comments on commit ad6ce04

Please sign in to comment.