diff --git a/README.md b/README.md index b36f422..2cd4f14 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,14 @@ constant, but this code has yet to be written. The second sort of false positive is based on a limitation in the sort of analysis SafeSQL performs: there are many safe SQL statements which are not feasible (or not possible) to represent as compile-time constants. More advanced -static analysis techniques (such as taint analysis) or user-provided safety -annotations would be able to reduce the number of false positives, but this is -expected to be a significant undertaking. +static analysis techniques (such as taint analysis). + +In order to ignore false positives, add the following comment to the line before +or the same line as the statement: +``` +//nolint:safesql +``` + +Even if a statement is ignored it will still be logged, but will not cause +safesql to exit with a status code of 1 if all found statements are ignored. diff --git a/safesql.go b/safesql.go index adf8bb8..8a13193 100644 --- a/safesql.go +++ b/safesql.go @@ -4,11 +4,14 @@ package main import ( + "bufio" "flag" "fmt" "go/build" + "go/token" "go/types" "os" + "sort" "path/filepath" "strings" @@ -20,6 +23,8 @@ import ( "golang.org/x/tools/go/ssa/ssautil" ) +const IgnoreComment = "nolint:safesql" + type sqlPackage struct { packageName string paramNames []string @@ -133,16 +138,37 @@ func main() { fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) } + potentialBadStatements := []token.Position{} for _, ci := range bad { - pos := p.Fset.Position(ci.Pos()) - fmt.Printf("- %s\n", pos) + potentialBadStatements = append(potentialBadStatements, p.Fset.Position(ci.Pos())) + } + + issues, err := CheckIssues(potentialBadStatements) + if err != nil { + fmt.Printf("error when checking for ignore comments: %v\n", err) + os.Exit(2) } + if verbose { fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") fmt.Println("You should always use parameterized queries or prepared statements") fmt.Println("instead of building queries from strings.") } - os.Exit(1) + + hasNonIgnoredUnsafeStatement := false + + for _, issue := range issues { + if issue.ignored { + fmt.Printf("- %s is potentially unsafe but ignored by comment\n", issue.statement) + } else { + fmt.Printf("- %s\n", issue.statement) + hasNonIgnoredUnsafeStatement = true + } + } + + if hasNonIgnoredUnsafeStatement { + os.Exit(1) + } } // QueryMethod represents a method on a type which has a string parameter named @@ -154,6 +180,70 @@ type QueryMethod struct { Param int } +type Issue struct { + statement token.Position + ignored bool +} + +// CheckIssues checks lines to see if the line before or the current line has an ignore comment and marks those +// statements that have the ignore comment on the current line or the line before +func CheckIssues(lines []token.Position) ([]Issue, error) { + files := make(map[string][]token.Position) + + for _, line := range lines { + files[line.Filename] = append(files[line.Filename], line) + } + + issues := []Issue{} + + for file, linesInFile := range files { + // ensure we have the lines in ascending order + sort.Slice(linesInFile, func(i, j int) bool { return linesInFile[i].Line < linesInFile[j].Line }) + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + s := bufio.NewScanner(f) + + currentLine := 0 + for _, line := range linesInFile { + // check the line before the problematic statement first + potentialCommentLine := line.Line - 1 + + // if there are 2 statements back to back that are ignored then we don't want to check the previous so skip + // ie + // db.Query(query, args) //IsSqlSafe + // db.Query(query2, args2) + if currentLine != potentialCommentLine { + for ; currentLine < potentialCommentLine; currentLine++ { + if !s.Scan() { + return nil, s.Err() + } + } + if HasIgnoreComment(s.Text()) { + issues = append(issues, Issue{statement: line, ignored: true}) + continue + } + } + + // check the line of the statement + if !s.Scan() { + return nil, s.Err() + } + isIgnored := HasIgnoreComment(s.Text()) + issues = append(issues, Issue{statement: line, ignored: isIgnored}) + } + } + + return issues, nil +} + +func HasIgnoreComment(line string) bool { + return strings.HasSuffix(strings.TrimSpace(line), IgnoreComment) +} + // FindQueryMethods locates all methods in the given package (assumed to be // package database/sql) with a string parameter named "query". func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod {