Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes (part 2) #149

Merged
merged 8 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 66 additions & 30 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,25 @@ func (c *compiler) compile(path string) error {
pkgColors[fn] = color
}

// Add all packages from the module. Although these packages don't contain
// yield points, they may return closures that need to be serialized. For
// this to work, certain functions need to be marked as noinline and function
// literal types need to be registered.
//
// TODO: improve this by scanning dependencies to see if they need to be included
packages.Visit(pkgs, func(p *packages.Package) bool {
if p.Module == nil || p.Module.Dir != moduleDir {
return true
}
if p.PkgPath == coroutinePackage {
return true
}
if _, ok := colorsByPkg[p]; !ok {
colorsByPkg[p] = functionColors{}
}
return true
}, nil)

if c.onlyListFiles {
cwd, _ := os.Getwd()
for pkg := range colorsByPkg {
Expand Down Expand Up @@ -356,16 +375,43 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er

case *ast.FuncDecl:
color := colorsByFunc[decl]
if color == nil && !containsColoredFuncLit(decl, colorsByFunc) {
gen.Decls = append(gen.Decls, decl)
continue

compiled := false
if color != nil || containsColoredFuncLit(decl, colorsByFunc) {
// Reject certain language features for now.
if err := unsupported(decl, p.TypesInfo); err != nil {
return err
}
scope := &scope{compiler: c, colors: colorsByFunc}
decl = scope.compileFuncDecl(p, decl, color)
compiled = true
}
// Reject certain language features for now.
if err := unsupported(decl, p.TypesInfo); err != nil {
return err

if compiled || containsFuncLit(decl) {
// If the function declaration contains function literals, we have to
// add the //go:noinline copmiler directive to prevent inlining or the
// resulting symbol name generated by the linker wouldn't match the
// predictions made in generateFunctypes.
//
// When functions are inlined, the linker creates a unique name
// combining the symbol name of the calling function and the symbol name
// of the closure. Knowing which functions will be inlined is difficult
// considering the score-base mechansim that Go uses and alterations
// like PGO, therefore we take the simple approach of disabling inlining
// instead.
//
// Note that we only need to do this for single-expression functions as
// otherwise the presence of a defer statement to unwind the coroutine
// already prevents inlining, however, it's simpler to always add the
// compiler directive.
if decl.Doc == nil {
decl.Doc = &ast.CommentGroup{}
}
decl.Doc.List = appendCommentGroup(decl.Doc.List, decl.Doc)
decl.Doc.List = appendComment(decl.Doc.List, "//go:noinline\n")
}
scope := &scope{compiler: c, colors: colorsByFunc}
gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color))

gen.Decls = append(gen.Decls, decl)
}
}

Expand Down Expand Up @@ -400,6 +446,17 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types
return
}

func containsFuncLit(decl *ast.FuncDecl) (yes bool) {
ast.Inspect(decl, func(n ast.Node) bool {
if _, ok := n.(*ast.FuncLit); ok {
yes = true
return false
Comment on lines +452 to +453
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand this statement, yes is the named return, which we set to true but then immediately override by returning false?

Copy link
Contributor Author

@chriso chriso Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ast.Inspect is an iterator that expects a function to return whether it wants to continue iteration in that part of the AST. Once we find a function literal, we return false to terminate iteration.

I copied the helper function from above and changed it slightly, since we're looking for any function literal, not just colored function literals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😓 my bad, I misread this, thanks for the precision.

}
return true
})
return
}

func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File {
imports := map[string]string{}

Expand Down Expand Up @@ -488,7 +545,7 @@ type scope struct {
}

func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl {
log.Printf("compiling function %s %s", p.Name, fn.Name)
log.Printf("compiling function %s.%s", p.Name, fn.Name)

// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
Expand All @@ -502,34 +559,13 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
Body: scope.compileFuncBody(p, fnType, fn.Body, fn.Recv, color),
}

// If the function declaration contains function literals, we have to
// add the //go:noinline copmiler directive to prevent inlining or the
// resulting symbol name generated by the linker wouldn't match the
// predictions made in generateFunctypes.
//
// When functions are inlined, the linker creates a unique name
// combining the symbol name of the calling function and the symbol name
// of the closure. Knowing which functions will be inlined is difficult
// considering the score-base mechansim that Go uses and alterations
// like PGO, therefore we take the simple approach of disabling inlining
// instead.
//
// Note that we only need to do this for single-expression functions as
// otherwise the presence of a defer statement to unwind the coroutine
// already prevents inlining, however, it's simpler to always add the
// compiler directive.
gen.Doc.List = appendCommentGroup(gen.Doc.List, fn.Doc)
gen.Doc.List = appendComment(gen.Doc.List, "//go:noinline\n")

if color != nil && !isExpr(gen.Body) {
scope.colors[gen] = color
}
return gen
}

func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit {
log.Printf("compiling function literal %s", p.Name)

gen := &ast.FuncLit{
Type: funcTypeWithNamedResults(p, fn),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
Expand Down
12 changes: 12 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { InterfaceEmbedded() },
yields: []int{1, 1, 1},
},

{
name: "closure in separate package",
coro: func() { ClosureInSeparatePackage(3) },
yields: []int{3, 4, 5},
},

{
name: "closure via generic with struct type param",
coro: func() { GenericStructClosure(3) },
yields: []int{3, 5, 7},
},
}

// This emulates the installation of function type information by the
Expand Down
35 changes: 19 additions & 16 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"go/ast"
"go/token"
"go/types"
"log"
"maps"
"slices"
"strconv"
Expand Down Expand Up @@ -283,7 +282,6 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma
if len(instances) == 0 {
// This can occur when a generic function is never instantiated/used,
// or when it's instantiated in a package not known to the compiler.
log.Printf("warning: cannot register runtime type information for generic function %s", fn)
continue
}
for _, instance := range instances {
Expand Down Expand Up @@ -489,20 +487,7 @@ func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type {
}

func (g *genericInstance) partial() bool {
sig := g.instance.Signature
params := sig.Params()
for i := 0; i < params.Len(); i++ {
if _, ok := params.At(i).Type().(*types.TypeParam); ok {
return true
}
}
results := sig.Results()
for i := 0; i < results.Len(); i++ {
if _, ok := results.At(i).Type().(*types.TypeParam); ok {
return true
}
}
return false
return containsTypeParam(g.instance.Signature)
}

func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) {
Expand Down Expand Up @@ -585,6 +570,24 @@ func writeGoShape(b *strings.Builder, tt types.Type) {
} else {
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
case *types.Struct:
b.WriteString("struct { ")
for i := 0; i < t.NumFields(); i++ {
if i > 0 {
b.WriteString("; ")
}
f := t.Field(i)

if f.Embedded() {
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
b.WriteString(f.Pkg().Path())
b.WriteByte('.')
b.WriteString(f.Name())
b.WriteByte(' ')
b.WriteString(f.Type().String())
}
b.WriteString(" }")
default:
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
Expand Down
36 changes: 36 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"unsafe"

"github.com/dispatchrun/coroutine"
"github.com/dispatchrun/coroutine/compiler/testdata/subpkg"
)

//go:generate coroc
Expand Down Expand Up @@ -741,3 +742,38 @@ func InterfaceEmbedded() {
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
}

func ClosureInSeparatePackage(n int) {
adder := subpkg.Adder(n)
for i := 0; i < n; i++ {
coroutine.Yield[int, any](adder(i))
}
}

func GenericStructClosure(n int) {
impl := AdderImpl{base: n, mul: 2}

boxed := &GenericAdder[AdderImpl]{adder: impl}
for i := 0; i < n; i++ {
coroutine.Yield[int, any](boxed.Add(i))
}
}

type adder interface {
Add(int) int
}

type AdderImpl struct {
base int
mul int
}

func (a AdderImpl) Add(n int) int { return a.base + n*a.mul }

var _ adder = AdderImpl{}

type GenericAdder[A adder] struct{ adder A }

func (b *GenericAdder[A]) Add(n int) int {
return b.adder.Add(n)
}
Loading
Loading