diff --git a/compiler/compile.go b/compiler/compile.go index f7194e8..abd821f 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -47,6 +47,8 @@ func Compile(path string, options ...Option) error { type Option func(*compiler) type compiler struct { + prog *ssa.Program + generics map[*ssa.Function][]*ssa.Function coroutinePkg *packages.Package fset *token.FileSet @@ -112,13 +114,26 @@ func (c *compiler) compile(path string) error { } log.Printf("building SSA program") - prog, _ := ssautil.AllPackages(pkgs, ssa.InstantiateGenerics|ssa.GlobalDebug) - prog.Build() + c.prog, _ = ssautil.AllPackages(pkgs, ssa.InstantiateGenerics|ssa.GlobalDebug) + c.prog.Build() log.Printf("building call graph") - cg := vta.CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + cg := vta.CallGraph(ssautil.AllFunctions(c.prog), cha.CallGraph(c.prog)) + + log.Printf("collecting generic instances") + c.generics = map[*ssa.Function][]*ssa.Function{} + for fn := range ssautil.AllFunctions(c.prog) { + if fn.Signature.TypeParams() != nil { + if _, ok := c.generics[fn]; !ok { + c.generics[fn] = nil + } + } + if origin := fn.Origin(); origin != nil { + c.generics[origin] = append(c.generics[origin], fn) + } + } - log.Printf("finding generic yield instantiations") + log.Printf("finding yield points") packages.Visit(pkgs, func(p *packages.Package) bool { if p.PkgPath == coroutinePackage { c.coroutinePkg = p @@ -129,10 +144,10 @@ func (c *compiler) compile(path string) error { log.Printf("%s not imported by the module. Nothing to do", coroutinePackage) return nil } - yieldFunc := prog.FuncValue(c.coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func)) + yieldFunc := c.prog.FuncValue(c.coroutinePkg.Types.Scope().Lookup("Yield").(*types.Func)) yieldInstances := functionColors{} - for fn := range ssautil.AllFunctions(prog) { - if fn.Origin() == yieldFunc { + if fns, ok := c.generics[yieldFunc]; ok { + for _, fn := range fns { yieldInstances[fn] = fn.Signature } } @@ -149,11 +164,17 @@ func (c *compiler) compile(path string) error { }, nil) colorsByPkg := map[*packages.Package]functionColors{} for fn, color := range colors { - if fn.Pkg == nil { + pkg := fn.Pkg + if pkg == nil { + if origin := fn.Origin(); origin != nil { + pkg = origin.Pkg + } + } + if pkg == nil { return fmt.Errorf("unsupported yield function %s (Pkg is nil)", fn) } - p := pkgsByTypes[fn.Pkg.Pkg] + p := pkgsByTypes[pkg.Pkg] pkgColors := colorsByPkg[p] if pkgColors == nil { pkgColors = functionColors{} @@ -306,7 +327,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er } } - generateFunctypes(p, gen, colorsByFunc) + c.generateFunctypes(p, gen, colorsByFunc) // Find all the required imports for this file. gen = addImports(p, gen) diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index 6282738..d6e92c8 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -214,6 +214,24 @@ func TestCoroutineYield(t *testing.T) { yields: []int{11}, result: 42, }, + + { + name: "identity generic", + coro: func() { IdentityGeneric[int](11) }, + yields: []int{11}, + }, + + { + name: "identity generic (2)", + coro: func() { IdentityGenericInt(11) }, + yields: []int{11}, + }, + + { + name: "identity generic (3)", + coro: func() { IdentityGenericStructInt(11) }, + yields: []int{11}, + }, } // This emulates the installation of function type information by the diff --git a/compiler/function.go b/compiler/function.go index 65acdc9..e468bd0 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -5,6 +5,7 @@ import ( "go/ast" "go/token" "go/types" + "log" "maps" "slices" "strconv" @@ -207,15 +208,23 @@ func functionPath(p *packages.Package, f *ast.FuncDecl) string { return packagePath(p) + "." + f.Name.Name } -func generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) { +func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) { functypes := map[string]functype{} for _, decl := range f.Decls { switch d := decl.(type) { case *ast.FuncDecl: - scope := &funcscope{vars: map[string]*funcvar{}} - name := functionPath(p, d) - collectFunctypes(p, name, d, scope, colors, functypes) + obj := p.TypesInfo.ObjectOf(d.Name).(*types.Func) + fn := c.prog.FuncValue(obj) + if fn.TypeParams() != nil { + // TODO: support generics. Generate type func/closure type information for each instance from: instances := c.generics[fn] + log.Printf("warning: cannot register runtime type information for generic function %s", fn) + continue + } else { + scope := &funcscope{vars: map[string]*funcvar{}} + name := functionPath(p, d) + collectFunctypes(p, name, d, scope, colors, functypes) + } } } diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index f3e53e2..92d7218 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -557,3 +557,23 @@ func ReturnNamedValue() (out int) { out = 42 return } + +func IdentityGeneric[T any](n T) { + coroutine.Yield[T, any](n) +} + +type IdentityGenericStruct[T any] struct { + n T +} + +func (i *IdentityGenericStruct[T]) Run() { + coroutine.Yield[T, any](i.n) +} + +func IdentityGenericInt(n int) { + IdentityGeneric[int](n) +} + +func IdentityGenericStructInt(n int) { + (&IdentityGenericStruct[int]{n: n}).Run() +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 5203438..f13ce47 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3236,12 +3236,30 @@ func ReturnNamedValue() (_fn0 int) { } panic("unreachable") } + +//go:noinline +func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) } + +type IdentityGenericStruct[T any] struct { + n T +} + +//go:noinline +func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) } + +//go:noinline +func IdentityGenericInt(n int) { IdentityGeneric[int](n) } + +//go:noinline +func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") diff --git a/compiler/types.go b/compiler/types.go index 4e4ec5f..eae8442 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -92,6 +92,12 @@ func typeExpr(p *packages.Package, typ types.Type) ast.Expr { c.Dir = ast.RECV } return c + + case *types.TypeParam: + obj := t.Obj() + ident := ast.NewIdent(obj.Name()) + p.TypesInfo.Defs[ident] = obj + return ident } panic(fmt.Sprintf("not implemented: %T", typ)) } diff --git a/types/types.go b/types/types.go index 2ddc28c..443537f 100644 --- a/types/types.go +++ b/types/types.go @@ -399,6 +399,9 @@ func (m *funcmap) RegisterAddr(addr unsafe.Pointer) (id funcid, closureType refl if f == nil { panic(fmt.Sprintf("function not found at address %v", addr)) } + if f.Type == nil { + panic(fmt.Sprintf("type information not registered for function %s (%p)", f.Name, addr)) + } var closureTypeID typeid if f.Closure != nil {