Skip to content

Commit

Permalink
coroc: don't register type information for functions with type parame…
Browse files Browse the repository at this point in the history
…ters (#124)

This partially fixes
#123.

This PR updates the compiler to ignore generic functions when
registering type information for the runtime and its serialization layer
(via `types.RegisterFunc` and `types.RegisterClosure`).

This expands the range of Go programs with generics that are supported.
If the input program never serializes a reference to a generic function
or a closure created within a generic function, the compiler output is
valid. For cases where the input attempts to serialize a generic
function or nested closure, a runtime check has been implemented. This
check triggers a panic if type information is unavailable for the
particular function.
  • Loading branch information
chriso authored Dec 12, 2023
2 parents a1ca737 + 069cac4 commit 8e8f3bf
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 14 deletions.
41 changes: 31 additions & 10 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func Compile(path string, options ...Option) error {
type Option func(*compiler)

type compiler struct {
prog *ssa.Program
generics map[*ssa.Function][]*ssa.Function
coroutinePkg *packages.Package

fset *token.FileSet
Expand Down Expand Up @@ -112,13 +114,26 @@ func (c *compiler) compile(path string) error {
}

log.Printf("building SSA program")
prog, _ := ssautil.AllPackages(pkgs, ssa.InstantiateGenerics|ssa.GlobalDebug)
prog.Build()
c.prog, _ = ssautil.AllPackages(pkgs, ssa.InstantiateGenerics|ssa.GlobalDebug)
c.prog.Build()

log.Printf("building call graph")
cg := vta.CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
cg := vta.CallGraph(ssautil.AllFunctions(c.prog), cha.CallGraph(c.prog))

log.Printf("collecting generic instances")
c.generics = map[*ssa.Function][]*ssa.Function{}
for fn := range ssautil.AllFunctions(c.prog) {
if fn.Signature.TypeParams() != nil {
if _, ok := c.generics[fn]; !ok {
c.generics[fn] = nil
}
}
if origin := fn.Origin(); origin != nil {
c.generics[origin] = append(c.generics[origin], fn)
}
}

log.Printf("finding generic yield instantiations")
log.Printf("finding yield points")
packages.Visit(pkgs, func(p *packages.Package) bool {
if p.PkgPath == coroutinePackage {
c.coroutinePkg = p
Expand All @@ -129,10 +144,10 @@ func (c *compiler) compile(path string) error {
log.Printf("%s not imported by the module. Nothing to do", coroutinePackage)
return nil
}
yieldFunc := prog.FuncValue(c.coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func))
yieldFunc := c.prog.FuncValue(c.coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func))
yieldInstances := functionColors{}
for fn := range ssautil.AllFunctions(prog) {
if fn.Origin() == yieldFunc {
if fns, ok := c.generics[yieldFunc]; ok {
for _, fn := range fns {
yieldInstances[fn] = fn.Signature
}
}
Expand All @@ -149,11 +164,17 @@ func (c *compiler) compile(path string) error {
}, nil)
colorsByPkg := map[*packages.Package]functionColors{}
for fn, color := range colors {
if fn.Pkg == nil {
pkg := fn.Pkg
if pkg == nil {
if origin := fn.Origin(); origin != nil {
pkg = origin.Pkg
}
}
if pkg == nil {
return fmt.Errorf("unsupported yield function %s (Pkg is nil)", fn)
}

p := pkgsByTypes[fn.Pkg.Pkg]
p := pkgsByTypes[pkg.Pkg]
pkgColors := colorsByPkg[p]
if pkgColors == nil {
pkgColors = functionColors{}
Expand Down Expand Up @@ -306,7 +327,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
}
}

generateFunctypes(p, gen, colorsByFunc)
c.generateFunctypes(p, gen, colorsByFunc)

// Find all the required imports for this file.
gen = addImports(p, gen)
Expand Down
18 changes: 18 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,24 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{11},
result: 42,
},

{
name: "identity generic",
coro: func() { IdentityGeneric[int](11) },
yields: []int{11},
},

{
name: "identity generic (2)",
coro: func() { IdentityGenericInt(11) },
yields: []int{11},
},

{
name: "identity generic (3)",
coro: func() { IdentityGenericStructInt(11) },
yields: []int{11},
},
}

// This emulates the installation of function type information by the
Expand Down
17 changes: 13 additions & 4 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"go/ast"
"go/token"
"go/types"
"log"
"maps"
"slices"
"strconv"
Expand Down Expand Up @@ -207,15 +208,23 @@ func functionPath(p *packages.Package, f *ast.FuncDecl) string {
return packagePath(p) + "." + f.Name.Name
}

func generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) {
func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) {
functypes := map[string]functype{}

for _, decl := range f.Decls {
switch d := decl.(type) {
case *ast.FuncDecl:
scope := &funcscope{vars: map[string]*funcvar{}}
name := functionPath(p, d)
collectFunctypes(p, name, d, scope, colors, functypes)
obj := p.TypesInfo.ObjectOf(d.Name).(*types.Func)
fn := c.prog.FuncValue(obj)
if fn.TypeParams() != nil {
// TODO: support generics. Generate type func/closure type information for each instance from: instances := c.generics[fn]
log.Printf("warning: cannot register runtime type information for generic function %s", fn)
continue
} else {
scope := &funcscope{vars: map[string]*funcvar{}}
name := functionPath(p, d)
collectFunctypes(p, name, d, scope, colors, functypes)
}
}
}

Expand Down
20 changes: 20 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,23 @@ func ReturnNamedValue() (out int) {
out = 42
return
}

func IdentityGeneric[T any](n T) {
coroutine.Yield[T, any](n)
}

type IdentityGenericStruct[T any] struct {
n T
}

func (i *IdentityGenericStruct[T]) Run() {
coroutine.Yield[T, any](i.n)
}

func IdentityGenericInt(n int) {
IdentityGeneric[int](n)
}

func IdentityGenericStructInt(n int) {
(&IdentityGenericStruct[int]{n: n}).Run()
}
18 changes: 18 additions & 0 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3236,12 +3236,30 @@ func ReturnNamedValue() (_fn0 int) {
}
panic("unreachable")
}

//go:noinline
func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) }

type IdentityGenericStruct[T any] struct {
n T
}

//go:noinline
func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) }

//go:noinline
func IdentityGenericInt(n int) { IdentityGeneric[int](n) }

//go:noinline
func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() }
func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue")
_types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator")
_types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops")
Expand Down
6 changes: 6 additions & 0 deletions compiler/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr {
c.Dir = ast.RECV
}
return c

case *types.TypeParam:
obj := t.Obj()
ident := ast.NewIdent(obj.Name())
p.TypesInfo.Defs[ident] = obj
return ident
}
panic(fmt.Sprintf("not implemented: %T", typ))
}
Expand Down
3 changes: 3 additions & 0 deletions types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ func (m *funcmap) RegisterAddr(addr unsafe.Pointer) (id funcid, closureType refl
if f == nil {
panic(fmt.Sprintf("function not found at address %v", addr))
}
if f.Type == nil {
panic(fmt.Sprintf("type information not registered for function %s (%p)", f.Name, addr))
}

var closureTypeID typeid
if f.Closure != nil {
Expand Down

0 comments on commit 8e8f3bf

Please sign in to comment.