Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #148

Merged
merged 9 commits into from
Jun 17, 2024
Merged

Fixes #148

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
c.generateFunctypes(p, gen, colorsByFunc)

// Find all the required imports for this file.
gen = addImports(p, gen)
gen = addImports(p, f, gen)

outputPath := strings.TrimSuffix(p.GoFiles[i], ".go")
outputPath += "_durable.go"
Expand Down Expand Up @@ -400,7 +400,7 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types
return
}

func addImports(p *packages.Package, gen *ast.File) *ast.File {
func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File {
imports := map[string]string{}

ast.Inspect(gen, func(n ast.Node) bool {
Expand Down Expand Up @@ -438,6 +438,15 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
}

importspecs := make([]ast.Spec, 0, len(imports))

// Preserve underscore (side effect) imports.
for _, imp := range f.Imports {
if imp.Name != nil && imp.Name.Name == "_" {
importspecs = append(importspecs, imp)
}
}

// Add imports for all packages used in the file.
for name, path := range imports {
importspecs = append(importspecs, &ast.ImportSpec{
Name: ast.NewIdent(name),
Expand Down Expand Up @@ -526,6 +535,8 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

p.TypesInfo.Types[gen] = types.TypeAndValue{Type: p.TypesInfo.TypeOf(fn)}

if !isExpr(gen.Body) {
scope.colors[gen] = color
}
Expand Down
27 changes: 27 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package compiler

import (
"math"
"reflect"
"slices"
"testing"

Expand Down Expand Up @@ -220,6 +222,11 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { StructClosure(3) },
yields: []int{10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},
{
name: "generic closure capturing receiver and param",
coro: func() { StructGenericClosure(3) },
yields: []int{10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},
{
name: "generic function",
coro: func() { IdentityGenericInt(11) },
Expand Down Expand Up @@ -255,6 +262,26 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { RangeOverInt(3) },
yields: []int{0, 1, 2},
},

{
name: "reflect type",
coro: func() {
ReflectType(reflect.TypeFor[uint8](), reflect.TypeFor[uint16]())
},
yields: []int{math.MaxUint8, math.MaxUint16},
},

{
name: "ellipsis closure",
coro: func() { EllipsisClosure(3) },
yields: []int{-1, 0, 1, 2},
},

{
name: "interface embedded",
coro: func() { InterfaceEmbedded() },
yields: []int{1, 1, 1},
},
}

// This emulates the installation of function type information by the
Expand Down
7 changes: 6 additions & 1 deletion compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
typeArg = g.typeArgOf
}

signature := copyFunctionType(functionTypeOf(fn))
signature := copyFunctionType(funcTypeWithNamedResults(p, fn))
signature.TypeParams = nil

recv := copyFieldList(functionRecvOf(fn))
Expand Down Expand Up @@ -182,6 +182,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
fieldName := ast.NewIdent(fmt.Sprintf("X%d", i))
fieldType := freeVar.typ

// Convert ellipsis into slice (...X => []X).
if e, ok := fieldType.(*ast.Ellipsis); ok {
fieldType = &ast.ArrayType{Elt: e.Elt}
}

// The Go compiler uses a more advanced mechanism to determine if a
// free variable should be captured by pointer or by value: it looks
// at whether the variable is reassigned, its address taken, and if
Expand Down
79 changes: 79 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package testdata

import (
"math"
"reflect"
"time"
"unsafe"

Expand Down Expand Up @@ -586,6 +588,34 @@ func StructClosure(n int) {
}
}

type GenericBox[T integer] struct {
x T
}

func (b *GenericBox[T]) YieldAndInc() {
coroutine.Yield[T, any](b.x)
b.x++
}

func (b *GenericBox[T]) Closure(y T) func(T) {
return func(z T) {
coroutine.Yield[T, any](b.x)
coroutine.Yield[T, any](y)
coroutine.Yield[T, any](z)
b.x++
y++
z++ // mutation is lost
}
}

func StructGenericClosure(n int) {
box := GenericBox[int]{10}
fn := box.Closure(100)
for i := 0; i < n; i++ {
fn(1000)
}
}

func IdentityGeneric[T any](n T) {
coroutine.Yield[T, any](n)
}
Expand Down Expand Up @@ -662,3 +692,52 @@ func RangeOverInt(n int) {
coroutine.Yield[int, any](i)
}
}

func ReflectType(types ...reflect.Type) {
for _, t := range types {
v := reflect.New(t).Elem()
if !v.CanUint() {
panic("expected uint type")
}
v.SetUint(math.MaxUint64)
coroutine.Yield[int, any](int(v.Uint()))
}
}

func MakeEllipsisClosure(ints ...int) func() {
return func() {
x := ints
for _, v := range x {
coroutine.Yield[int, any](v)
}
}
}

func EllipsisClosure(n int) {
ints := make([]int, n)
for i := range ints {
ints[i] = i
}
c := MakeEllipsisClosure(ints...)
coroutine.Yield[int, any](-1)
c()
}

type innerInterface interface {
Value() int
}

type innerInterfaceImpl int

func (i innerInterfaceImpl) Value() int { return int(i) }

type outerInterface interface {
innerInterface
}

func InterfaceEmbedded() {
var x interface{ outerInterface } = innerInterfaceImpl(1)
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
}
Loading
Loading