From 3219aea671e65ce3e2894b4f4dc43c7e244d4f1a Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 13 Dec 2023 10:57:56 +1000 Subject: [PATCH] Convert types when registering generic functions/closures --- compiler/function.go | 26 +++++++++-- compiler/testdata/coroutine_durable.go | 63 +++++++++++++++++++++----- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/compiler/function.go b/compiler/function.go index 3b438a6..49afe6a 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -88,11 +88,25 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func if fields != nil { for _, field := range fields.List { for _, name := range field.Names { + typ := p.TypesInfo.TypeOf(name) + if g != nil { + if instanceType, ok := g.typeOfParam(typ); ok { + typ = instanceType + } + } + if typ != nil { + _, ellipsis := field.Type.(*ast.Ellipsis) + field.Type = typeExpr(p, typ) + if a, ok := field.Type.(*ast.ArrayType); ok && a.Len == nil && ellipsis { + field.Type = &ast.Ellipsis{Elt: a.Elt} + } + } scope.insert(name, field.Type) } } } } + signature.TypeParams = nil var inspect func(ast.Node) bool inspect = func(node ast.Node) bool { @@ -110,11 +124,17 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func switch s := spec.(type) { case *ast.ValueSpec: for _, name := range s.Names { - typ := s.Type + typ := p.TypesInfo.TypeOf(name) + if g != nil { + if instanceType, ok := g.typeOfParam(typ); ok { + typ = instanceType + } + } if typ == nil { - typ = typeExpr(p, p.TypesInfo.TypeOf(name)) + scope.insert(name, s.Type) + } else { + scope.insert(name, typeExpr(p, typ)) } - scope.insert(name, typ) } } } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 941ffdc..3985528 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -4,15 +4,17 @@ package testdata import ( coroutine "github.com/stealthrocket/coroutine" - unsafe "unsafe" time "time" + unsafe "unsafe" ) import _types "github.com/stealthrocket/coroutine/types" func SomeFunctionThatShouldExistInTheCompiledFile() { } + //go:noinline func Identity(n int) { coroutine.Yield[int, any](n) } + //go:noinline func SquareGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -48,6 +50,7 @@ func SquareGenerator(_fn0 int) { } } } + //go:noinline func SquareGeneratorTwice(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -78,6 +81,7 @@ func SquareGeneratorTwice(_fn0 int) { SquareGenerator(_f0.X0) } } + //go:noinline func SquareGeneratorTwiceLoop(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -113,6 +117,7 @@ func SquareGeneratorTwiceLoop(_fn0 int) { } } } + //go:noinline func EvenSquareGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -160,6 +165,7 @@ func EvenSquareGenerator(_fn0 int) { } } } + //go:noinline func NestedLoops(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -241,6 +247,7 @@ func NestedLoops(_fn0 int) (_ int) { } panic("unreachable") } + //go:noinline func FizzBuzzIfGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -301,6 +308,7 @@ func FizzBuzzIfGenerator(_fn0 int) { } } } + //go:noinline func FizzBuzzSwitchGenerator(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -385,6 +393,7 @@ func FizzBuzzSwitchGenerator(_fn0 int) { } } } + //go:noinline func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() @@ -727,6 +736,7 @@ func Shadowing(_ int) { coroutine.Yield[int, any](_f0.X22) } } + //go:noinline func RangeSliceIndexGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -769,6 +779,7 @@ func RangeSliceIndexGenerator(_ int) { } } } + //go:noinline func RangeArrayIndexValueGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -825,6 +836,7 @@ func RangeArrayIndexValueGenerator(_ int) { } } } + //go:noinline func TypeSwitchingGenerator(_ int) { _c := coroutine.LoadContext[int, any]() @@ -899,6 +911,7 @@ func TypeSwitchingGenerator(_ int) { } } } + //go:noinline func LoopBreakAndContinue(_ int) { _c := coroutine.LoadContext[int, any]() @@ -1053,6 +1066,7 @@ func LoopBreakAndContinue(_ int) { } } } + //go:noinline func RangeOverMaps(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1387,6 +1401,7 @@ func RangeOverMaps(_fn0 int) { } } } + //go:noinline func Range(_fn0 int, _fn1 func(int)) { _c := coroutine.LoadContext[int, any]() @@ -1425,13 +1440,15 @@ func Range(_fn0 int, _fn1 func(int)) { } } } + //go:noinline func Double(n int) { coroutine.Yield[int, any](2 * n) } + //go:noinline func RangeTriple(n int) { - Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }, - ) + Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }) } + //go:noinline func RangeTripleFuncValue(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1466,6 +1483,7 @@ func RangeTripleFuncValue(_fn0 int) { Range(_f0.X0, _f0.X1) } } + //go:noinline func RangeReverseClosureCaptureByValue(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -1515,6 +1533,7 @@ func RangeReverseClosureCaptureByValue(_fn0 int) { } } } + //go:noinline func Range10ClosureCapturingValues() { _c := coroutine.LoadContext[int, any]() @@ -1621,6 +1640,7 @@ func Range10ClosureCapturingValues() { } } } + //go:noinline func Range10ClosureCapturingPointers() { _c := coroutine.LoadContext[int, any]() @@ -1737,6 +1757,7 @@ func Range10ClosureCapturingPointers() { } } } + //go:noinline func Range10ClosureHeterogenousCapture() { _c := coroutine.LoadContext[int, any]() @@ -1956,6 +1977,7 @@ func Range10ClosureHeterogenousCapture() { } } } + //go:noinline func Range10Heterogenous() { _c := coroutine.LoadContext[int, any]() @@ -2067,6 +2089,7 @@ func Range10Heterogenous() { } } } + //go:noinline func Select(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -2361,6 +2384,7 @@ func Select(_fn0 int) { } } } + //go:noinline func YieldingExpressionDesugaring() { _c := coroutine.LoadContext[int, any]() @@ -2783,6 +2807,7 @@ func YieldingExpressionDesugaring() { } } } + //go:noinline func a(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2814,6 +2839,7 @@ func a(_fn0 int) (_ int) { } panic("unreachable") } + //go:noinline func b(_fn0 int) (_ int) { _c := coroutine.LoadContext[int, any]() @@ -2845,6 +2871,7 @@ func b(_fn0 int) (_ int) { } panic("unreachable") } + //go:noinline func YieldingDurations() { _c := coroutine.LoadContext[int, any]() @@ -2952,6 +2979,7 @@ func YieldingDurations() { } } } + //go:noinline func YieldAndDeferAssign(_fn0 *int, _fn1, _fn2 int) { _c := coroutine.LoadContext[int, any]() @@ -2996,6 +3024,7 @@ func YieldAndDeferAssign(_fn0 *int, _fn1, _fn2 int) { coroutine.Yield[int, any](_f0.X1) } } + //go:noinline func RangeYieldAndDeferAssign(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3033,6 +3062,7 @@ func RangeYieldAndDeferAssign(_fn0 int) { } type MethodGeneratorState struct{ i int } + //go:noinline func (_fn0 *MethodGeneratorState) MethodGenerator(_fn1 int) { _c := coroutine.LoadContext[int, any]() @@ -3069,6 +3099,7 @@ func (_fn0 *MethodGeneratorState) MethodGenerator(_fn1 int) { } } } + //go:noinline func VarArgs(_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3108,6 +3139,7 @@ func VarArgs(_fn0 int) { varArgs(_f0.X1...) } } + //go:noinline func varArgs(_fn0 ...int) { _c := coroutine.LoadContext[int, any]() @@ -3164,6 +3196,7 @@ func varArgs(_fn0 ...int) { } } } + //go:noinline func ReturnNamedValue() (_fn0 int) { _c := coroutine.LoadContext[int, any]() @@ -3203,21 +3236,26 @@ func ReturnNamedValue() (_fn0 int) { } panic("unreachable") } + //go:noinline -func IdentityGeneric[T any](n T) { coroutine.Yield[T, any](n) } +func IdentityGeneric(n int) { coroutine.Yield[T, any](n) } type IdentityGenericStruct[T any] struct { n T } + //go:noinline func (i *IdentityGenericStruct[T]) Run() { coroutine.Yield[T, any](i.n) } + //go:noinline func IdentityGenericInt(n int) { IdentityGeneric[int](n) } + //go:noinline func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + //go:noinline -func IdentityGenericClosure[T any](_fn0 T) { - _c := coroutine.LoadContext[T, any]() +func IdentityGenericClosure(_fn0 int) { + _c := coroutine.LoadContext[int, any]() var _f0 *struct { IP int X0 T @@ -3256,11 +3294,12 @@ func IdentityGenericClosure[T any](_fn0 T) { // TODO: add this go:noinline directive automatically // //go:noinline -func buildClosure[T any](n T) func() { +func buildClosure(n T) func() { return func() { coroutine.Yield[T, any](n) } } + //go:noinline func IdentityGenericClosureInt(n int) { IdentityGenericClosure[int](n) } func init() { @@ -3271,10 +3310,10 @@ func init() { _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosureInt") - _types.RegisterFunc[func[T any](_fn0 T)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosure[go.shape.int]") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericClosure[go.shape.int]") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") - _types.RegisterFunc[func[T any](n T)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGeneric[go.shape.int]") + _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGeneric[go.shape.int]") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") @@ -3399,12 +3438,12 @@ func init() { _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.YieldingExpressionDesugaring") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.a") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.b") - _types.RegisterFunc[func[T any](n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int]") + _types.RegisterFunc[func(n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int]") _types.RegisterClosure[func(), struct { F uintptr - X0 T + X0 int }]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.int].func1") - _types.RegisterFunc[func[T any](n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.interface{}]") + _types.RegisterFunc[func(n T) func()]("github.com/stealthrocket/coroutine/compiler/testdata.buildClosure[go.shape.interface{}]") _types.RegisterClosure[func(), struct { F uintptr X0 T