Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #148

Merged
merged 9 commits into from
Jun 17, 2024
Merged

Fixes #148

Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
c.generateFunctypes(p, gen, colorsByFunc)

// Find all the required imports for this file.
gen = addImports(p, gen)
gen = addImports(p, f, gen)

outputPath := strings.TrimSuffix(p.GoFiles[i], ".go")
outputPath += "_durable.go"
Expand Down Expand Up @@ -400,7 +400,7 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types
return
}

func addImports(p *packages.Package, gen *ast.File) *ast.File {
func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File {
imports := map[string]string{}

ast.Inspect(gen, func(n ast.Node) bool {
Expand Down Expand Up @@ -438,6 +438,15 @@ func addImports(p *packages.Package, gen *ast.File) *ast.File {
}

importspecs := make([]ast.Spec, 0, len(imports))

// Preserve underscore (side effect) imports.
for _, imp := range f.Imports {
if imp.Name != nil && imp.Name.Name == "_" {
importspecs = append(importspecs, imp)
}
}

// Add imports for all packages used in the file.
for name, path := range imports {
importspecs = append(importspecs, &ast.ImportSpec{
Name: ast.NewIdent(name),
Expand Down Expand Up @@ -526,6 +535,8 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
}

p.TypesInfo.Types[gen] = types.TypeAndValue{Type: p.TypesInfo.TypeOf(fn)}

if !isExpr(gen.Body) {
scope.colors[gen] = color
}
Expand Down
7 changes: 6 additions & 1 deletion compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
typeArg = g.typeArgOf
}

signature := copyFunctionType(functionTypeOf(fn))
signature := copyFunctionType(funcTypeWithNamedResults(p, fn))
signature.TypeParams = nil

recv := copyFieldList(functionRecvOf(fn))
Expand Down Expand Up @@ -182,6 +182,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
fieldName := ast.NewIdent(fmt.Sprintf("X%d", i))
fieldType := freeVar.typ

// Convert ellipsis into slice (...X => []X).
if e, ok := fieldType.(*ast.Ellipsis); ok {
fieldType = &ast.ArrayType{Elt: e.Elt}
}

// The Go compiler uses a more advanced mechanism to determine if a
// free variable should be captured by pointer or by value: it looks
// at whether the variable is reassigned, its address taken, and if
Expand Down
22 changes: 18 additions & 4 deletions compiler/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam
if t.Empty() {
return ast.NewIdent("any")
}
if t.NumEmbeddeds() > 0 {
panic("not implemented: interface with embeddeds")
}
methods := make([]*ast.Field, t.NumExplicitMethods())
for i := range methods {
m := t.ExplicitMethod(i)
Expand All @@ -63,8 +60,16 @@ func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam
Type: typeExpr(p, m.Type(), typeArg),
}
}
embeddeds := make([]*ast.Field, t.NumEmbeddeds())
for i := range embeddeds {
embeddeds[i] = &ast.Field{
Type: typeExpr(p, t.EmbeddedType(i), typeArg),
}
}
return &ast.InterfaceType{
Methods: &ast.FieldList{List: methods},
Methods: &ast.FieldList{
List: append(methods, embeddeds...),
},
}
case *types.Signature:
return newFuncType(p, t, typeArg)
Expand Down Expand Up @@ -197,6 +202,15 @@ func substituteTypeArgs(p *packages.Package, expr ast.Expr, typeArg func(*types.
X: substituteTypeArgs(p, e.X, typeArg),
Index: substituteTypeArgs(p, e.Index, typeArg),
}
case *ast.IndexListExpr:
indices := make([]ast.Expr, len(e.Indices))
for i, index := range e.Indices {
indices[i] = substituteTypeArgs(p, index, typeArg)
}
return &ast.IndexListExpr{
X: substituteTypeArgs(p, e.X, typeArg),
Indices: indices,
}
case *ast.Ident:
t := p.TypesInfo.TypeOf(e)
tp, ok := t.(*types.TypeParam)
Expand Down
Loading