Skip to content

Commit

Permalink
Unpack generic instance info
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 13, 2023
1 parent f9bc0ca commit 84565b3
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 49 deletions.
218 changes: 212 additions & 6 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"maps"
"slices"
"strconv"
"strings"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
)

type functype struct {
Expand Down Expand Up @@ -57,7 +59,7 @@ type funcvar struct {
typ ast.Expr
}

func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype) {
func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *funcscope, colors map[ast.Node]*types.Signature, functypes map[string]functype, g *genericInstance) {
type function struct {
node ast.Node
scope *funcscope
Expand Down Expand Up @@ -191,7 +193,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func

for i, anonFunc := range anonFuncs[index:] {
anonFuncName := anonFuncLinkName(name, index+i+1)
collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes)
collectFunctypes(p, anonFuncName, anonFunc.node, anonFunc.scope, colors, functypes, g)
}
}
}
Expand All @@ -217,13 +219,23 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma
obj := p.TypesInfo.ObjectOf(d.Name).(*types.Func)
fn := c.prog.FuncValue(obj)
if fn.TypeParams() != nil {
// TODO: support generics. Generate type func/closure type information for each instance from: instances := c.generics[fn]
log.Printf("warning: cannot register runtime type information for generic function %s", fn)
continue
instances := c.generics[fn]
if len(instances) == 0 {
// This can occur when a generic function is never instantiated/used,
// or when it's instantiated in a package not known to the compiler.
log.Printf("warning: cannot register runtime type information for generic function %s", fn)
continue
}
for _, instance := range instances {
g := newGenericInstance(fn, instance)
scope := &funcscope{vars: map[string]*funcvar{}}
name := g.gcshapePath()
collectFunctypes(p, name, instance.Syntax(), scope, colors, functypes, g)
}
} else {
scope := &funcscope{vars: map[string]*funcvar{}}
name := functionPath(p, d)
collectFunctypes(p, name, d, scope, colors, functypes)
collectFunctypes(p, name, d, scope, colors, functypes, nil)
}
}
}
Expand Down Expand Up @@ -312,3 +324,197 @@ func functionBodyOf(fn ast.Node) *ast.BlockStmt {
panic("node is neither *ast.FuncDecl or *ast.FuncLit")
}
}

type genericInstance struct {
origin *ssa.Function
instance *ssa.Function

recvPtr bool
recvType *types.Named
recvArgs map[types.Type]types.Type

typeParams map[types.Type]*types.TypeParam

types map[types.Type]types.Type
}

func newGenericInstance(origin, instance *ssa.Function) *genericInstance {
g := &genericInstance{origin: origin, instance: instance}

g.typeParams = map[types.Type]*types.TypeParam{}
typeParams := g.origin.Signature.TypeParams()
for i := 0; i < typeParams.Len(); i++ {
p := typeParams.At(i)
pt := p.Obj().Type()
g.typeParams[pt] = p
}

if recv := g.instance.Signature.Recv(); recv != nil {
switch t := recv.Type().(type) {
case *types.Pointer:
g.recvPtr = true
switch pt := t.Elem().(type) {
case *types.Named:
g.recvType = pt
default:
panic(fmt.Sprintf("not implemented: %T", t))
}

case *types.Named:
g.recvType = t
default:
panic(fmt.Sprintf("not implemented: %T", t))
}
}

g.types = map[types.Type]types.Type{}
if g.recvType != nil {
g.scanRecvTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) {
g.types[p.Obj().Type()] = arg
})
}
g.scanTypeArgs(func(p *types.TypeParam, _ int, arg types.Type) {
g.types[p.Obj().Type()] = arg
})

return g
}

func (g *genericInstance) typeOfParam(t types.Type) (types.Type, bool) {
v, ok := g.types[t]
return v, ok
}

func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) {
typeParams := g.recvType.TypeParams()
typeArgs := g.recvType.TypeArgs()
for i := 0; i < typeArgs.Len(); i++ {
arg := typeArgs.At(i)
param := typeParams.At(i)

fn(param, i, arg)
}
}

func (g *genericInstance) scanTypeArgs(fn func(*types.TypeParam, int, types.Type)) {
knownTypes := map[*types.TypeParam]map[types.Type]struct{}{}
g.scanParams(func(p *types.TypeParam, i int, arg *types.Var) {
t, ok := knownTypes[p]
if !ok {
t = map[types.Type]struct{}{}
knownTypes[p] = t
}
t[arg.Type()] = struct{}{}
})
g.scanResults(func(p *types.TypeParam, i int, res *types.Var) {
t, ok := knownTypes[p]
if !ok {
t = map[types.Type]struct{}{}
knownTypes[p] = t
}
t[res.Type()] = struct{}{}
})

typeParams := g.origin.Signature.TypeParams()
for i := 0; i < typeParams.Len(); i++ {
p := typeParams.At(i)
types := knownTypes[p]

switch len(types) {
case 0:
panic(fmt.Sprintf("not implemented: no usage of type param %s in function %s instance %s", p, g.origin, g.instance))
case 1:
for knownType := range types {
fn(p, i, knownType)
}
default:
panic(fmt.Sprintf("not implemented: more than one type registered for type param %s in function %s instance %s", p, g.origin, g.instance))
}
}
}

func (g *genericInstance) scanParams(fn func(*types.TypeParam, int, *types.Var)) {
originParams := g.origin.Signature.Params()
params := g.instance.Signature.Params()
for i := 0; i < params.Len(); i++ {
param := params.At(i)
op := originParams.At(i)
if tp, ok := g.typeParams[op.Type()]; ok {
fn(tp, i, param)
}
}
}

func (g *genericInstance) scanResults(fn func(*types.TypeParam, int, *types.Var)) {
originResults := g.origin.Signature.Results()
results := g.instance.Signature.Results()
for i := 0; i < results.Len(); i++ {
result := results.At(i)
op := originResults.At(i)
if tp, ok := g.typeParams[op.Type()]; ok {
fn(tp, i, result)
}
}
}

func (g *genericInstance) gcshapePath() string {
var path strings.Builder

path.WriteString(g.origin.Pkg.Pkg.Path())

if g.recvType != nil {
path.WriteByte('.')
if g.recvPtr {
path.WriteString("(*")
}
path.WriteString(g.recvType.Obj().Name())

if g.recvType.TypeParams() != nil {
path.WriteByte('[')
g.scanRecvTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) {
if i > 0 {
path.WriteString(",")
}
writeGoShape(&path, arg)
})
path.WriteByte(']')
}

if g.recvPtr {
path.WriteByte(')')
}
}

path.WriteByte('.')
path.WriteString(g.instance.Object().(*types.Func).Name())

if g.origin.Signature.TypeParams() != nil {
path.WriteByte('[')
g.scanTypeArgs(func(_ *types.TypeParam, i int, arg types.Type) {
if i > 0 {
path.WriteString(",")
}
writeGoShape(&path, arg)
})
path.WriteByte(']')
}

return path.String()
}

func writeGoShape(b *strings.Builder, tt types.Type) {
b.WriteString("go.shape.")

switch t := tt.Underlying().(type) {
case *types.Basic:
b.WriteString(t.Name())
case *types.Interface:
if t.Empty() {
b.WriteString("interface{}")
} else {
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
default:
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
}
Loading

0 comments on commit 84565b3

Please sign in to comment.