From 84565b34badb9ada989d8cb990343984995c4e1d Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 13 Dec 2023 10:21:39 +1000 Subject: [PATCH] Unpack generic instance info --- compiler/function.go | 218 ++++++++++++++++++++++++- compiler/testdata/coroutine_durable.go | 60 ++----- 2 files changed, 229 insertions(+), 49 deletions(-) diff --git a/compiler/function.go b/compiler/function.go index e468bd0..3b438a6 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -9,9 +9,11 @@ import ( "maps" "slices" "strconv" + "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" ) type functype struct { @@ -57,7 +59,7 @@ type funcvar struct { typ ast.Expr } -func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype) { +func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype, g *genericInstance) { type function struct { node ast.Node scope *funcscope @@ -191,7 +193,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func for i, anonFunc := range anonFuncs[index:] { anonFuncName := anonFuncLinkName(name, index+i+1) - collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes) + collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes, g) } } } @@ -217,13 +219,23 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma obj := p.TypesInfo.ObjectOf(d.Name).(*types.Func) fn := c.prog.FuncValue(obj) if fn.TypeParams() != nil { - // TODO: support generics. Generate type func/closure type information for each instance from: instances := c.generics[fn] - log.Printf("warning: cannot register runtime type information for generic function %s", fn) - continue + instances := c.generics[fn] + if len(instances) == 0 { + // This can occur when a generic function is never instantiated/used, + // or when it's instantiated in a package not known to the compiler. + log.Printf("warning: cannot register runtime type information for generic function %s", fn) + continue + } + for _, instance := range instances { + g := newGenericInstance(fn, instance) + scope := &funcscope{vars: map[string]*funcvar{}} + name := g.gcshapePath() + collectFunctypes(p, name, instance.Syntax(), scope, colors, functypes, g) + } } else { scope := &funcscope{vars: map[string]*funcvar{}} name := functionPath(p, d) - collectFunctypes(p, name, d, scope, colors, functypes) + collectFunctypes(p, name, d, scope, colors, functypes, nil) } } } @@ -312,3 +324,197 @@ func functionBodyOf(fn ast.Node) *ast.BlockStmt { panic("node is neither *ast.FuncDecl or *ast.FuncLit") } } + +type genericInstance struct { + origin *ssa.Function + instance *ssa.Function + + recvPtr bool + recvType *types.Named + recvArgs map[types.Type]types.Type + + typeParams map[types.Type]*types.TypeParam + + types map[types.Type]types.Type +} + +func newGenericInstance(origin, instance *ssa.Function) *genericInstance { + g := &genericInstance{origin: origin, instance: instance} + + g.typeParams = map[types.Type]*types.TypeParam{} + typeParams := g.origin.Signature.TypeParams() + for i := 0; i < typeParams.Len(); i++ { + p := typeParams.At(i) + pt := p.Obj().Type() + g.typeParams[pt] = p + } + + if recv := g.instance.Signature.Recv(); recv != nil { + switch t := recv.Type().(type) { + case *types.Pointer: + g.recvPtr = true + switch pt := t.Elem().(type) { + case *types.Named: + g.recvType = pt + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + + case *types.Named: + g.recvType = t + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + } + + g.types = map[types.Type]types.Type{} + if g.recvType != nil { + g.scanRecvTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { + g.types[p.Obj().Type()] = arg + }) + } + g.scanTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) { + g.types[p.Obj().Type()] = arg + }) + + return g +} + +func (g *genericInstance) typeOfParam(t types.Type) (types.Type, bool) { + v, ok := g.types[t] + return v, ok +} + +func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) { + typeParams := g.recvType.TypeParams() + typeArgs := g.recvType.TypeArgs() + for i := 0; i < typeArgs.Len(); i++ { + arg := typeArgs.At(i) + param := typeParams.At(i) + + fn(param, i, arg) + } +} + +func (g *genericInstance) scanTypeArgs(fn func(*types.TypeParam, int, types.Type)) { + knownTypes := map[*types.TypeParam]map[types.Type]struct{}{} + g.scanParams(func(p *types.TypeParam, i int, arg *types.Var) { + t, ok := knownTypes[p] + if !ok { + t = map[types.Type]struct{}{} + knownTypes[p] = t + } + t[arg.Type()] = struct{}{} + }) + g.scanResults(func(p *types.TypeParam, i int, res *types.Var) { + t, ok := knownTypes[p] + if !ok { + t = map[types.Type]struct{}{} + knownTypes[p] = t + } + t[res.Type()] = struct{}{} + }) + + typeParams := g.origin.Signature.TypeParams() + for i := 0; i < typeParams.Len(); i++ { + p := typeParams.At(i) + types := knownTypes[p] + + switch len(types) { + case 0: + panic(fmt.Sprintf("not implemented: no usage of type param %s in function %s instance %s", p, g.origin, g.instance)) + case 1: + for knownType := range types { + fn(p, i, knownType) + } + default: + panic(fmt.Sprintf("not implemented: more than one type registered for type param %s in function %s instance %s", p, g.origin, g.instance)) + } + } +} + +func (g *genericInstance) scanParams(fn func(*types.TypeParam, int, *types.Var)) { + originParams := g.origin.Signature.Params() + params := g.instance.Signature.Params() + for i := 0; i < params.Len(); i++ { + param := params.At(i) + op := originParams.At(i) + if tp, ok := g.typeParams[op.Type()]; ok { + fn(tp, i, param) + } + } +} + +func (g *genericInstance) scanResults(fn func(*types.TypeParam, int, *types.Var)) { + originResults := g.origin.Signature.Results() + results := g.instance.Signature.Results() + for i := 0; i < results.Len(); i++ { + result := results.At(i) + op := originResults.At(i) + if tp, ok := g.typeParams[op.Type()]; ok { + fn(tp, i, result) + } + } +} + +func (g *genericInstance) gcshapePath() string { + var path strings.Builder + + path.WriteString(g.origin.Pkg.Pkg.Path()) + + if g.recvType != nil { + path.WriteByte('.') + if g.recvPtr { + path.WriteString("(*") + } + path.WriteString(g.recvType.Obj().Name()) + + if g.recvType.TypeParams() != nil { + path.WriteByte('[') + g.scanRecvTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) { + if i > 0 { + path.WriteString(",") + } + writeGoShape(&path, arg) + }) + path.WriteByte(']') + } + + if g.recvPtr { + path.WriteByte(')') + } + } + + path.WriteByte('.') + path.WriteString(g.instance.Object().(*types.Func).Name()) + + if g.origin.Signature.TypeParams() != nil { + path.WriteByte('[') + g.scanTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) { + if i > 0 { + path.WriteString(",") + } + writeGoShape(&path, arg) + }) + path.WriteByte(']') + } + + return path.String() +} + +func writeGoShape(b *strings.Builder, tt types.Type) { + b.WriteString("go.shape.") + + switch t := tt.Underlying().(type) { + case *types.Basic: + b.WriteString(t.Name()) + case *types.Interface: + if t.Empty() { + b.WriteString("interface{}") + } else { + panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) + } + default: + panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) + } +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 81f3314..941ffdc 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -4,17 +4,15 @@ package testdata import ( coroutine "github.com/stealthrocket/coroutine" - time "time" unsafe "unsafe" + time "time" ) import _types "github.com/stealthrocket/coroutine/types" func SomeFunctionThatShouldExistInTheCompiledFile() { } - //go:noinline func Identity(n int) { coroutine.Yield[int, any](n) } - //go:noinline func SquareGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -50,7 +48,6 @@ func SquareGenerator(_fn0 int) { } } } - //go:noinline func SquareGeneratorTwice(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -81,7 +78,6 @@ func SquareGeneratorTwice(_fn0 int) { SquareGenerator(_f0.X0) } } - //go:noinline func SquareGeneratorTwiceLoop(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -117,7 +113,6 @@ func SquareGeneratorTwiceLoop(_fn0 int) { } } } - //go:noinline func EvenSquareGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -165,7 +160,6 @@ func EvenSquareGenerator(_fn0 int) { } } } - //go:noinline func NestedLoops(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -247,7 +241,6 @@ func NestedLoops(_fn0 int) (_ int) { } panic("unreachable") } - //go:noinline func FizzBuzzIfGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -308,7 +301,6 @@ func FizzBuzzIfGenerator(_fn0 int) { } } } - //go:noinline func FizzBuzzSwitchGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -393,7 +385,6 @@ func FizzBuzzSwitchGenerator(_fn0 int) { } } } - //go:noinline func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() @@ -736,7 +727,6 @@ func Shadowing(_ int) { coroutine.Yield[int, any](_f0.X22) } } - //go:noinline func RangeSliceIndexGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -779,7 +769,6 @@ func RangeSliceIndexGenerator(_ int) { } } } - //go:noinline func RangeArrayIndexValueGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -836,7 +825,6 @@ func RangeArrayIndexValueGenerator(_ int) { } } } - //go:noinline func TypeSwitchingGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -911,7 +899,6 @@ func TypeSwitchingGenerator(_ int) { } } } - //go:noinline func LoopBreakAndContinue(_ int) { _c := coroutine.LoadContext[int, any]() @@ -1066,7 +1053,6 @@ func LoopBreakAndContinue(_ int) { } } } - //go:noinline func RangeOverMaps(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1401,7 +1387,6 @@ func RangeOverMaps(_fn0 int) { } } } - //go:noinline func Range(_fn0 int, _fn1 func(int)) { _c := coroutine.LoadContext[int, any]() @@ -1440,15 +1425,13 @@ func Range(_fn0 int, _fn1 func(int)) { } } } - //go:noinline func Double(n int) { coroutine.Yield[int, any](2 * n) } - //go:noinline func RangeTriple(n int) { - Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }) + Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }, + ) } - //go:noinline func RangeTripleFuncValue(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1483,7 +1466,6 @@ func RangeTripleFuncValue(_fn0 int) { Range(_f0.X0, _f0.X1) } } - //go:noinline func RangeReverseClosureCaptureByValue(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1533,7 +1515,6 @@ func RangeReverseClosureCaptureByValue(_fn0 int) { } } } - //go:noinline func Range10ClosureCapturingValues() { _c := coroutine.LoadContext[int, any]() @@ -1640,7 +1621,6 @@ func Range10ClosureCapturingValues() { } } } - //go:noinline func Range10ClosureCapturingPointers() { _c := coroutine.LoadContext[int, any]() @@ -1757,7 +1737,6 @@ func Range10ClosureCapturingPointers() { } } } - //go:noinline func Range10ClosureHeterogenousCapture() { _c := coroutine.LoadContext[int, any]() @@ -1977,7 +1956,6 @@ func Range10ClosureHeterogenousCapture() { } } } - //go:noinline func Range10Heterogenous() { _c := coroutine.LoadContext[int, any]() @@ -2089,7 +2067,6 @@ func Range10Heterogenous() { } } } - //go:noinline func Select(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -2384,7 +2361,6 @@ func Select(_fn0 int) { } } } - //go:noinline func YieldingExpressionDesugaring() { _c := coroutine.LoadContext[int, any]() @@ -2807,7 +2783,6 @@ func YieldingExpressionDesugaring() { } } } - //go:noinline func a(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2839,7 +2814,6 @@ func a(_fn0 int) (_ int) { } panic("unreachable") } - //go:noinline func b(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2871,7 +2845,6 @@ func b(_fn0 int) (_ int) { } panic("unreachable") } - //go:noinline func YieldingDurations() { _c := coroutine.LoadContext[int, any]() @@ -2979,7 +2952,6 @@ func YieldingDurations() { } } } - //go:noinline func YieldAndDeferAssign(_fn0 *int, _fn1, _fn2 int) { _c := coroutine.LoadContext[int, any]() @@ -3024,7 +2996,6 @@ func YieldAndDeferAssign(_fn0 *int, _fn1, _fn2 int) { coroutine.Yield[int, any](_f0.X1) } } - //go:noinline func RangeYieldAndDeferAssign(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3062,7 +3033,6 @@ func RangeYieldAndDeferAssign(_fn0 int) { } type MethodGeneratorState struct{ i int } - //go:noinline func (_fn0 *MethodGeneratorState) MethodGenerator(_fn1 int) { _c := coroutine.LoadContext[int, any]() @@ -3099,7 +3069,6 @@ func (_fn0 *MethodGeneratorState) MethodGenerator(_fn1 int) { } } } - //go:noinline func VarArgs(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3139,7 +3108,6 @@ func VarArgs(_fn0 int) { varArgs(_f0.X1...) } } - //go:noinline func varArgs(_fn0 ...int) { _c := coroutine.LoadContext[int, any]() @@ -3196,7 +3164,6 @@ func varArgs(_fn0 ...int) { } } } - //go:noinline func ReturnNamedValue() (_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3236,26 +3203,21 @@ func ReturnNamedValue() (_fn0 int) { } panic("unreachable") } - //go:noinline func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) } type IdentityGenericStruct[T any] struct { n T } - //go:noinline func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) } - //go:noinline func IdentityGenericInt(n int) { IdentityGeneric[int](n) } - //go:noinline func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } - //go:noinline func IdentityGenericClosure[T any](_fn0 T) { - _c := coroutine.LoadContext[int, any]() + _c := coroutine.LoadContext[T, any]() var _f0 *struct { IP int X0 T @@ -3299,18 +3261,20 @@ func buildClosure[T any](n T) func() { coroutine.Yield[T, any](n) } } - //go:noinline func IdentityGenericClosureInt(n int) { IdentityGenericClosure[int](n) } func init() { + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Run") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosureInt") + _types.RegisterFunc[func[T any](_fn0 T)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosure[go.shape.int]") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") + _types.RegisterFunc[func[T any](n T)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGeneric[go.shape.int]") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") @@ -3435,5 +3399,15 @@ func init() { _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.YieldingExpressionDesugaring") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.a") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.b") + _types.RegisterFunc[func[T any](n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int]") + _types.RegisterClosure[func(), struct { + F uintptr + X0 T + }]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int].func1") + _types.RegisterFunc[func[T any](n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.interface{}]") + _types.RegisterClosure[func(), struct { + F uintptr + X0 T + }]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.interface{}].func1") _types.RegisterFunc[func(_fn0 ...int)]("github.com/stealthrocket/coroutine/compiler/testdata.varArgs") }