Skip to content

Commit

Permalink
gopls/internal/golang: fix bad slice append in function extraction
Browse files Browse the repository at this point in the history
With multiple return statements, a slice append would overwrite the
return values of earlier returns. Fix by using slices.Concat.

For golang/go#66289

Change-Id: Ib23bcb9ff297aa1ce9511c7ae54e692b14facca7
Reviewed-on: https://go-review.googlesource.com/c/tools/+/627537
Reviewed-by: Hongxiang Jiang <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
findleyr committed Nov 14, 2024
1 parent 5b5d57c commit 47a5f7d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
12 changes: 8 additions & 4 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go/parser"
"go/token"
"go/types"
"slices"
"sort"
"strings"
"text/scanner"
Expand Down Expand Up @@ -1214,10 +1215,14 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.
}
var name string
name, idx = generateAvailableIdentifier(pos, path, pkg, info, "returnValue", idx)
z := analysisinternal.ZeroValue(file, pkg, typ)
if z == nil {
return nil, nil, fmt.Errorf("can't generate zero value for %T", typ)
}
retVars = append(retVars, &returnVariable{
name: ast.NewIdent(name),
decl: &ast.Field{Type: expr},
zeroVal: analysisinternal.ZeroValue(file, pkg, typ),
zeroVal: z,
})
}
}
Expand Down Expand Up @@ -1250,8 +1255,7 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]
break
}
if val == nil {
return fmt.Errorf(
"could not find matching AST expression for %T", returnType.Type)
return fmt.Errorf("could not find matching AST expression for %T", returnType.Type)
}
zeroVals = append(zeroVals, val)
}
Expand All @@ -1266,7 +1270,7 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]
return false
}
if n, ok := n.(*ast.ReturnStmt); ok {
n.Results = append(zeroVals, n.Results...)
n.Results = slices.Concat(zeroVals, n.Results)
return false
}
return true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

-- a.go --
package a

import (
"fmt"
"encoding/json"
)

func F() error {
a, err := json.Marshal(0) //@codeaction("a", end, "refactor.extract.function", out)
if err != nil {
return fmt.Errorf("1: %w", err)
}
b, err := json.Marshal(0)
if err != nil {
return fmt.Errorf("2: %w", err)
} //@loc(end, "}")
fmt.Println(a, b)
return nil
}

-- @out/a.go --
package a

import (
"fmt"
"encoding/json"
)

func F() error {
//@codeaction("a", end, "refactor.extract.function", out)
a, b, shouldReturn, returnValue := newFunction()
if shouldReturn {
return returnValue
} //@loc(end, "}")
fmt.Println(a, b)
return nil
}

func newFunction() ([]byte, []byte, bool, error) {
a, err := json.Marshal(0)
if err != nil {
return nil, nil, true, fmt.Errorf("1: %w", err)
}
b, err := json.Marshal(0)
if err != nil {
return nil, nil, true, fmt.Errorf("2: %w", err)
}
return a, b, false, nil
}

4 changes: 3 additions & 1 deletion internal/analysisinternal/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ func ZeroValue(f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
case under.Info()&types.IsNumeric != 0:
return &ast.BasicLit{Kind: token.INT, Value: "0"}
case under.Info()&types.IsBoolean != 0:
return &ast.Ident{Name: "false"}
return ast.NewIdent("false")
case under.Info()&types.IsString != 0:
return &ast.BasicLit{Kind: token.STRING, Value: `""`}
case under == types.Typ[types.Invalid]:
return nil
default:
panic(fmt.Sprintf("unknown basic type %v", under))
}
Expand Down

0 comments on commit 47a5f7d

Please sign in to comment.