diff --git a/package15.go b/package15.go new file mode 100644 index 0000000..4c05efc --- /dev/null +++ b/package15.go @@ -0,0 +1,7 @@ +// +build !go1.6 + +package main + +import "os" + +var useVendor = os.Getenv("GO15VENDOREXPERIMENT") == "1" diff --git a/package16.go b/package16.go new file mode 100644 index 0000000..409faaf --- /dev/null +++ b/package16.go @@ -0,0 +1,7 @@ +// +build go1.6 + +package main + +import "os" + +var useVendor = os.Getenv("GO15VENDOREXPERIMENT") == "0" || os.Getenv("GO15VENDOREXPERIMENT") == "" diff --git a/safesql.go b/safesql.go index 3c4664a..3071879 100644 --- a/safesql.go +++ b/safesql.go @@ -6,8 +6,11 @@ package main import ( "flag" "fmt" + "go/build" "go/types" "os" + "path/filepath" + "strings" "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/loader" @@ -32,7 +35,9 @@ func main() { os.Exit(2) } - c := loader.Config{} + c := loader.Config{ + FindPackage: FindPackage, + } c.Import("database/sql") for _, pkg := range pkgs { c.Import(pkg) @@ -198,3 +203,46 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru return bad } + +// Deal with GO15VENDOREXPERIMENT +func FindPackage(ctxt *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) { + if !useVendor { + return ctxt.Import(path, dir, mode) + } + + // First, walk up the filesystem from dir looking for vendor directories + var vendorDir string + for tmp := dir; vendorDir == "" && tmp != "/"; tmp = filepath.Dir(tmp) { + dname := filepath.Join(tmp, "vendor", filepath.FromSlash(path)) + fd, err := os.Open(dname) + if err != nil { + continue + } + // Directories are only valid if they contain at least one file + // with suffix ".go" (this also ensures that the file descriptor + // we have is in fact a directory) + names, err := fd.Readdirnames(-1) + if err != nil { + continue + } + for _, name := range names { + if strings.HasSuffix(name, ".go") { + vendorDir = filepath.ToSlash(dname) + break + } + } + } + + if vendorDir != "" { + pkg, err := ctxt.ImportDir(vendorDir, mode) + if err != nil { + return nil, err + } + // Go tries to derive a valid import path for the package, but + // it's wrong (it includes "/vendor/"). Overwrite it here. + pkg.ImportPath = path + return pkg, nil + } + + return ctxt.Import(path, dir, mode) +}