diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/README.md b/README.md index 2cd4f14..2ef1191 100644 --- a/README.md +++ b/README.md @@ -83,3 +83,10 @@ or the same line as the statement: Even if a statement is ignored it will still be logged, but will not cause safesql to exit with a status code of 1 if all found statements are ignored. +Adding tests +--------------- +To add a test create a new director in `testdata` and add a go program in the +folder you created, for an example look at `testdata/multiple_files`. + +After adding a new directory and go program, add an entry to the tests map in +`safesql_test.go`, which will run the tests against the program added. \ No newline at end of file diff --git a/safesql_test.go b/safesql_test.go new file mode 100644 index 0000000..a2f53c8 --- /dev/null +++ b/safesql_test.go @@ -0,0 +1,95 @@ +package main + +import ( + "go/token" + "path" + "reflect" + "sort" + "testing" +) + +const testDir = "./testdata" + +// TestCheckIssues attempts to see if issues are ignored or not and annotates the issues if they are ignored +func TestCheckIssues(t *testing.T) { + tests := map[string]struct{ + tokens []token.Position + expected []Issue + }{ + "all_ignored": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 29, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 29, Column: 5 }, ignored: true}, + }, + }, + "ignored_back_to_back": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 22, Column: 5 }, + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 22, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: false}, + }, + }, + "single_ignored": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 29, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 29, Column: 5 }, ignored: false}, + }, + }, + "multiple_files": { + tokens: []token.Position{ + token.Position{Filename:"main.go", Line: 23, Column: 5 }, + token.Position{Filename:"main.go", Line: 24, Column: 5 }, + token.Position{Filename:"helpers.go", Line: 16, Column: 5 }, + token.Position{Filename:"helpers.go", Line: 17, Column: 5 }, + }, + expected: []Issue{ + Issue{statement: token.Position{Filename:"main.go", Line: 23, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"main.go", Line: 24, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"helpers.go", Line: 16, Column: 5 }, ignored: true}, + Issue{statement: token.Position{Filename:"helpers.go", Line: 17, Column: 5 }, ignored: true}, + }, + }, + } + + for name, expectations := range tests { + t.Run(name, func(t *testing.T) { + for idx, pos := range expectations.tokens { + expectations.tokens[idx].Filename = path.Join(testDir, name, pos.Filename) + } + for idx, issue := range expectations.expected { + expectations.expected[idx].statement.Filename = path.Join(testDir, name, issue.statement.Filename) + } + + issues, err := CheckIssues(expectations.tokens) + if err != nil { + t.Fatal(err) + } + + if len(issues) != len(expectations.expected) { + t.Fatal("The expected number of issues was not found") + } + + // sort to ensure we have the same issue order + sort.Slice(expectations.expected, func(i, j int) bool {return expectations.expected[i].statement.Filename < expectations.expected[j].statement.Filename }) + sort.Slice(issues, func(i, j int) bool {return issues[i].statement.Filename < issues[j].statement.Filename }) + + for idx, expected := range expectations.expected { + actual := issues[idx] + if !reflect.DeepEqual(actual, expected) { + t.Errorf("The actual value of %v did not match the expected %v", actual, expected) + } + } + }) + } +} \ No newline at end of file diff --git a/testdata/all_ignored/main.go b/testdata/all_ignored/main.go new file mode 100644 index 0000000..6fbb7c0 --- /dev/null +++ b/testdata/all_ignored/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// All issues are ignored in this test +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + //nolint:safesql + row := db.QueryRow(query) + var count int + if err := row.Scan(&count); err != nil { + return err + } + + row = db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + if err := row.Scan(&count); err != nil { + return err + } + + return nil +} diff --git a/testdata/ignored_back_to_back/main.go b/testdata/ignored_back_to_back/main.go new file mode 100644 index 0000000..a49b98f --- /dev/null +++ b/testdata/ignored_back_to_back/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) + + + return nil +} diff --git a/testdata/multiple_files/helpers.go b/testdata/multiple_files/helpers.go new file mode 100644 index 0000000..551b546 --- /dev/null +++ b/testdata/multiple_files/helpers.go @@ -0,0 +1,21 @@ +package main + +import ( + "database/sql" + "fmt" +) + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + + + return nil +} \ No newline at end of file diff --git a/testdata/multiple_files/main.go b/testdata/multiple_files/main.go new file mode 100644 index 0000000..fc0f178 --- /dev/null +++ b/testdata/multiple_files/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) + fmt.Println(query2("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to be an issue even though the line before has a comment +func query2(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + _ := db.QueryRow(query) //nolint:safesql + _ := db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) //nolint:safesql + + + return nil +} \ No newline at end of file diff --git a/testdata/single_ignored/main.go b/testdata/single_ignored/main.go new file mode 100644 index 0000000..3281eb6 --- /dev/null +++ b/testdata/single_ignored/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "database/sql" + "fmt" +) + +func main() { + fmt.Println(query("'test' OR 1=1")) +} + +const GetAllQuery = "SELECT COUNT(*) FROM t WHERE arg=%s" + +// For this test we expect the second QueryRow to have an SQL injection issue +func query(arg string) error { + db, err := sql.Open("postgres", "postgresql://test:test@test") + if err != nil { + return err + } + + query := fmt.Sprintf(GetAllQuery, arg) + //nolint:safesql + row := db.QueryRow(query) + var count int + if err := row.Scan(&count); err != nil { + return err + } + + row = db.QueryRow(fmt.Sprintf(GetAllQuery, "Catch me please?")) + if err := row.Scan(&count); err != nil { + return err + } + + return nil +}