diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/README.md b/README.md index b36f422..2ef1191 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,21 @@ constant, but this code has yet to be written. The second sort of false positive is based on a limitation in the sort of analysis SafeSQL performs: there are many safe SQL statements which are not feasible (or not possible) to represent as compile-time constants. More advanced -static analysis techniques (such as taint analysis) or user-provided safety -annotations would be able to reduce the number of false positives, but this is -expected to be a significant undertaking. +static analysis techniques (such as taint analysis). +In order to ignore false positives, add the following comment to the line before +or the same line as the statement: +``` +//nolint:safesql +``` + +Even if a statement is ignored it will still be logged, but will not cause +safesql to exit with a status code of 1 if all found statements are ignored. + +Adding tests +--------------- +To add a test create a new director in `testdata` and add a go program in the +folder you created, for an example look at `testdata/multiple_files`. + +After adding a new directory and go program, add an entry to the tests map in +`safesql_test.go`, which will run the tests against the program added. \ No newline at end of file diff --git a/safesql.go b/safesql.go index adf8bb8..9b2f543 100644 --- a/safesql.go +++ b/safesql.go @@ -7,8 +7,11 @@ import ( "flag" "fmt" "go/build" + "go/token" "go/types" + "io/ioutil" "os" + "sort" "path/filepath" "strings" @@ -20,6 +23,8 @@ import ( "golang.org/x/tools/go/ssa/ssautil" ) +const IgnoreComment = "//nolint:safesql" + type sqlPackage struct { packageName string paramNames []string @@ -133,16 +138,37 @@ func main() { fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) } + potentialBadStatements := []token.Position{} for _, ci := range bad { - pos := p.Fset.Position(ci.Pos()) - fmt.Printf("- %s\n", pos) + potentialBadStatements = append(potentialBadStatements, p.Fset.Position(ci.Pos())) + } + + issues, err := CheckIssues(potentialBadStatements) + if err != nil { + fmt.Printf("error when checking for ignore comments: %v\n", err) + os.Exit(2) } + if verbose { fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") fmt.Println("You should always use parameterized queries or prepared statements") fmt.Println("instead of building queries from strings.") } - os.Exit(1) + + hasNonIgnoredUnsafeStatement := false + + for _, issue := range issues { + if issue.ignored { + fmt.Printf("- %s is potentially unsafe but ignored by comment\n", issue.statement) + } else { + fmt.Printf("- %s\n", issue.statement) + hasNonIgnoredUnsafeStatement = true + } + } + + if hasNonIgnoredUnsafeStatement { + os.Exit(1) + } } // QueryMethod represents a method on a type which has a string parameter named @@ -154,6 +180,59 @@ type QueryMethod struct { Param int } +type Issue struct { + statement token.Position + ignored bool +} + +// CheckIssues checks lines to see if the line before or the current line has an ignore comment and marks those +// statements that have the ignore comment on the current line or the line before +func CheckIssues(lines []token.Position) ([]Issue, error) { + files := make(map[string][]token.Position) + + for _, line := range lines { + files[line.Filename] = append(files[line.Filename], line) + } + + issues := []Issue{} + + for file, linesInFile := range files { + // ensure we have the lines in ascending order + sort.Slice(linesInFile, func(i, j int) bool { return linesInFile[i].Line < linesInFile[j].Line }) + + data, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + fileLines := strings.Split(string(data), "\n") + + for _, line := range linesInFile { + // check the line before the problematic statement first + potentialCommentLine := line.Line - 2 + + // check only if the previous line is strictly a line that begins with + // the ignore comment + if 0 <= potentialCommentLine && BeginsWithComment(fileLines[potentialCommentLine]) { + issues = append(issues, Issue{statement: line, ignored: true}) + continue + } + + isIgnored := HasIgnoreComment(fileLines[line.Line-1]) + issues = append(issues, Issue{statement: line, ignored: isIgnored}) + } + } + + return issues, nil +} + +func BeginsWithComment(line string) bool { + return strings.HasPrefix(strings.TrimSpace(line), IgnoreComment) +} + +func HasIgnoreComment(line string) bool { + return strings.HasSuffix(strings.TrimSpace(line), IgnoreComment) +} + // FindQueryMethods locates all methods in the given package (assumed to be // package database/sql) with a string parameter named "query". func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod { diff --git a/safesql_test.go b/safesql_test.go new file mode 100644 index 0000000..a2f53c8 --- /dev/null +++ b/safesql_test.go @@ -0,0 +1,95 @@ +package main + +import ( + "go/token" + "path" + "reflect" + "sort" + "testing" +) + +const testDir = "./testdata" + +// TestCheckIssues attempts to see if issues are ignored or not and annotates the issues if they are ignored +func TestCheckIssues(t *testing.T) { + tests := map[string]struct{ + tokens []token.Position + expected []Issue + }{ + "all_ignored": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 29, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 29, Column: 5 }, ignored: true}, + }, + }, + "ignored_back_to_back": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 22, Column: 5 }, + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 22, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: false}, + }, + }, + "single_ignored": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 29, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 29, Column: 5 }, ignored: false}, + }, + }, + "multiple_files": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 24, Column: 5 }, + token.Position{Filename:"helpers.go", Line: 16, Column: 5 }, + token.Position{Filename:"helpers.go", Line: 17, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 24, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"helpers.go", Line: 16, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"helpers.go", Line: 17, Column: 5 }, ignored: true}, + }, + }, + } + + for name, expectations := range tests { + t.Run(name, func(t *testing.T) { + for idx, pos := range expectations.tokens { + expectations.tokens[idx].Filename = path.Join(testDir, name, pos.Filename) + } + for idx, issue := range expectations.expected { + expectations.expected[idx].statement.Filename = path.Join(testDir, name, issue.statement.Filename) + } + + issues, err := CheckIssues(expectations.tokens) + if err != nil { + t.Fatal(err) + } + + if len(issues) != len(expectations.expected) { + t.Fatal("The expected number of issues was not found") + } + + // sort to ensure we have the same issue order + sort.Slice(expectations.expected, func(i, j int) bool {return expectations.expected[i].statement.Filename < expectations.expected[j].statement.Filename }) + sort.Slice(issues, func(i, j int) bool {return issues[i].statement.Filename < issues[j].statement.Filename }) + + for idx, expected := range expectations.expected { + actual := issues[idx] + if !reflect.DeepEqual(actual, expected) { + t.Errorf("The actual value of %v did not match the expected %v", actual, expected) + } + } + }) + } +} \ No newline at end of file diff --git a/testdata/all_ignored/main.go b/testdata/all_ignored/main.go new file mode 100644 index 0000000..6fbb7c0 --- /dev/null +++ b/testdata/all_ignored/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// All issues are ignored in this test +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + //nolint:safesql + row := db.QueryRow(query) + var count int + if err := row.Scan(&count); err != nil { + return err + } + + row = db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + if err := row.Scan(&count); err != nil { + return err + } + + return nil +} diff --git a/testdata/ignored_back_to_back/main.go b/testdata/ignored_back_to_back/main.go new file mode 100644 index 0000000..a49b98f --- /dev/null +++ b/testdata/ignored_back_to_back/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) + + + return nil +} diff --git a/testdata/multiple_files/helpers.go b/testdata/multiple_files/helpers.go new file mode 100644 index 0000000..551b546 --- /dev/null +++ b/testdata/multiple_files/helpers.go @@ -0,0 +1,21 @@ +package main + +import ( + "database/sql" + "fmt" +) + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + + + return nil +} \ No newline at end of file diff --git a/testdata/multiple_files/main.go b/testdata/multiple_files/main.go new file mode 100644 index 0000000..fc0f178 --- /dev/null +++ b/testdata/multiple_files/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) + fmt.Println(query2("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query2(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + + + return nil +} \ No newline at end of file diff --git a/testdata/single_ignored/main.go b/testdata/single_ignored/main.go new file mode 100644 index 0000000..3281eb6 --- /dev/null +++ b/testdata/single_ignored/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to have an SQL injection issue +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + //nolint:safesql + row := db.QueryRow(query) + var count int + if err := row.Scan(&count); err != nil { + return err + } + + row = db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) + if err := row.Scan(&count); err != nil { + return err + } + + return nil +}