Skip to content

Commit

Permalink
Merge pull request #15 from CallMeGreg/callmegreg/more-validation
Browse files Browse the repository at this point in the history
Added repo name validation and settings validation
  • Loading branch information
CallMeGreg authored Sep 16, 2024
2 parents 436b195 + 671493a commit 68f845f
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 44 deletions.
20 changes: 6 additions & 14 deletions cmd/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,19 @@ func runAlerts(cmd *cobra.Command, args []string) (err error) {
// set scope & target based on the flag that was used:
scope, target, err := getScopeAndTarget()
if err != nil {
fmt.Println(err)
return
return err
}

// set the API URL based on the target:
requestPath, err := createGitHubSecretAlertsAPIPath(scope, target)
if err != nil {
fmt.Println(err)
return
return err
}

// update the URL to include query parameters based on specified flags:
parsedURL, err := url.Parse(requestPath)
if err != nil {
fmt.Println(err)
return
return err
}
values := parsedURL.Query()

Expand All @@ -49,7 +46,7 @@ func runAlerts(cmd *cobra.Command, args []string) (err error) {
}
per_page_int, err := strconv.Atoi(per_page)
if err != nil {
fmt.Println(err)
return err
}
values.Set("per_page", per_page)
// if provider was specified, filter results. Otherwise, return all results:
Expand All @@ -73,21 +70,17 @@ func runAlerts(cmd *cobra.Command, args []string) (err error) {
opts := setOptions()
client, err := api.NewRESTClient(opts)
if err != nil {
fmt.Println(err)
return err
}

for page := 1; page <= pages; page++ {
fmt.Println("Processing page: " + strconv.Itoa(page))
_, nextPage, err := callGitHubAPI(client, requestPath, &pageOfSecretAlerts, GET)
if err != nil {
fmt.Println("ERROR: Unable to get alerts for target: " + requestPath)
return err
}
for _, secretAlert := range pageOfSecretAlerts {
// add each secret alert in the response page to allSecretAlerts array
allSecretAlerts = append(allSecretAlerts, secretAlert)
}
// add each secret alert in the response page to allSecretAlerts array
allSecretAlerts = append(allSecretAlerts, pageOfSecretAlerts...)
var hasNextPage bool
if requestPath, hasNextPage = findNextPage(nextPage); !hasNextPage {
break
Expand All @@ -114,7 +107,6 @@ func runAlerts(cmd *cobra.Command, args []string) (err error) {
if len(sortedAlerts) > 0 && csvReport {
err = generateCSVReport(sortedAlerts, scope, false)
if err != nil {
fmt.Println(err)
return err
}
}
Expand Down
95 changes: 81 additions & 14 deletions cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func createGitHubSecretAlertsAPIPath(scope string, target string) (apiURL string
replacer := strings.NewReplacer("{owner}", owner, "{repo}", repo)
apiURL = replacer.Replace(repositoryAlertsURL)
default:
err = fmt.Errorf("Invalid API target.")
err = fmt.Errorf("invalid API target")
}
return apiURL, err
}
Expand Down Expand Up @@ -159,13 +159,11 @@ func callGitHubAPI(client *api.RESTClient, requestPath string, parseType interfa
nextPage := response.Header.Get("Link")
responseBody, err := io.ReadAll(response.Body)
if err != nil {
fmt.Println("ERROR: Unable to read next page link")
return response.StatusCode, nextPage, err
}

err = decodeJSONResponse(responseBody, &parseType)
if err != nil {
fmt.Println("ERROR: Unable to decode JSON response")
return response.StatusCode, nextPage, err
}

Expand All @@ -176,7 +174,6 @@ func decodeJSONResponse(body []byte, parseType interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(body))
err := decoder.Decode(&parseType)
if err != nil {
fmt.Println("ERROR: Unable to decode JSON response")
return err
}

Expand All @@ -200,11 +197,11 @@ func validateProvider(provider string) (err error) {
providerList = append(providerList, key)
}
for _, item := range providerList {
if strings.ToLower(item) == strings.ToLower(provider) {
if strings.EqualFold(item, provider) {
return nil
}
}
err = fmt.Errorf(Red("Invalid provider: " + provider + "\nValid providers are: " + strings.Join(providerList, ", ")))
err = fmt.Errorf("Invalid provider: " + provider + "\nValid providers are: " + strings.Join(providerList, ", "))
return err
}

Expand Down Expand Up @@ -250,6 +247,20 @@ func getScopeAndTarget() (scope string, target string, err error) {
target = organization
} else if repository != "" {
scope = "repository"
repoPattern := regexp.MustCompile(`^[^/]+/[^/]+$`)
if !repoPattern.MatchString(repository) {
err = errors.New("repository must follow the format 'owner/repository'")
return "", "", err
}
// check if secret scanning is enabled for the repository:
secretScanningEnabled, err := checkSecretScanningSetting(repository)
if err != nil {
return "", "", err
}
if !secretScanningEnabled {
err = errors.New("Secret scanning is not enabled for the repository: " + repository)
return "", "", err
}
target = repository
}
return scope, target, err
Expand Down Expand Up @@ -341,11 +352,10 @@ func generateCSVReport(alerts []Alert, scope string, validity_check bool) (err e
now := time.Now()
// Format the time as YYYYMMDD-HHMMSS
timestamp := now.Format("20060102-150405")
filename := "secretscanningreport-" + scope + "-" + timestamp + ".csv"
filename := "SecretScanningReport-" + scope + "-" + timestamp + ".csv"
// Create a CSV file
file, err := os.Create(filename)
if err != nil {
fmt.Println("ERROR: Error creating CSV file.")
return err
}
defer file.Close()
Expand Down Expand Up @@ -381,7 +391,6 @@ func generateCSVReport(alerts []Alert, scope string, validity_check bool) (err e
counter++
}
if err := writer.Error(); err != nil {
fmt.Println("ERROR: Error writing to CSV file.")
return err
}
fmt.Println(Blue("CSV report generated: " + filename))
Expand Down Expand Up @@ -420,19 +429,17 @@ func verifyAlerts(alerts []Alert) (alertsOutput []Alert, err error) {
}
client, err := api.NewHTTPClient(opts)
if err != nil {
fmt.Println("ERROR: Unable to create HTTP client.")
return alerts, err
}
// send a request to the validation endpoint:
var response *http.Response
if secret_validation_method == "POST" {
var body io.Reader
req, err := http.NewRequest("POST", alert.Validity_endpoint, body)
req, _ := http.NewRequest("POST", alert.Validity_endpoint, body)
req.Header.Set("Authorization", "Bearer "+alert.Secret)
req.Header.Set("Content-Type", secret_validation_content_type)
req.Header.Set("User-Agent", "gh-secret-scanning")
response, err = client.Do(req)
// response, err = client.Post(alert.Validity_endpoint, secret_validation_content_type, body)
if err != nil {
fmt.Println("WARNING: Unable to send " + secret_validation_method + " request to " + alert.Validity_endpoint)
continue
Expand Down Expand Up @@ -462,7 +469,7 @@ func verifyAlerts(alerts []Alert) (alertsOutput []Alert, err error) {
} else {
alert.Validity_boolean = false
}
if provider == "github" && alert.Validity_boolean == false && host != "github.com" {
if provider == "github" && !alert.Validity_boolean && host != "github.com" {
// also confirm validity with the provided GitHub Enterprise Server API:
alert = checkEnterpriseServerAPI(alert, client, secret_validation_method, secret_validation_content_type)
if alert.Validity_response_code == "200" {
Expand Down Expand Up @@ -510,7 +517,7 @@ func checkForExpectedBody(response *http.Response, expected_body_key string, exp
func checkEnterpriseServerAPI(alert Alert, client *http.Client, secret_validation_method string, secret_validation_content_type string) (alertOutput Alert) {
enterprise_server_api_endpoint := "https://" + host + "/api/v3/"
// create a new http request:
req, err := http.NewRequest("GET", enterprise_server_api_endpoint, nil)
req, _ := http.NewRequest("GET", enterprise_server_api_endpoint, nil)
req.Header.Set("Authorization", "Bearer "+alert.Secret)
req.Header.Set("Content-Type", secret_validation_content_type)
req.Header.Set("User-Agent", "gh-secret-scanning")
Expand Down Expand Up @@ -572,3 +579,63 @@ func createIssuesForValidAlerts(alerts []Alert) (err error) {
fmt.Println(Blue("Created " + strconv.Itoa(issue_count) + " issue(s)."))
return err
}

func checkSecretScanningSetting(repository string) (secretScanningEnabled bool, err error) {
// create a new client for the request:
var opts api.ClientOptions
opts.Headers = map[string]string{
"User-Agent": "gh-secret-scanning",
}
client, err := api.NewHTTPClient(opts)
if err != nil {
return false, fmt.Errorf("unable to create HTTP client")
}
// make a GET request to the repository API:
response, err := client.Get("https://api.github.com/repos/" + repository)

if err != nil {
return false, fmt.Errorf("unable to get repository information for " + repository)
}
defer response.Body.Close()

// check if the response has a 200 status code:
if response.StatusCode != http.StatusOK {
return false, fmt.Errorf("non-200 status code received: %d", response.StatusCode)
}

// parse the JSON response
var result map[string]interface{}
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
return false, fmt.Errorf("unable to parse JSON response")
}

// navigate to the required fields
securityAndAnalysis, ok := result["security_and_analysis"].(map[string]interface{})
if !ok {
return false, fmt.Errorf("security_and_analysis field not found")
}

advancedSecurity, ok := securityAndAnalysis["advanced_security"].(map[string]interface{})
if !ok {
return false, fmt.Errorf("advanced_security field not found")
}

status, ok := advancedSecurity["status"].(string)
if !ok || status != "enabled" {
return false, fmt.Errorf("Advanced Security is not enabled for " + repository + Yellow("\nEnable GitHub Advanced Security for the repository here: https://github.com/"+repository+"/settings/security_analysis"))
}

secretScanning, ok := securityAndAnalysis["secret_scanning"].(map[string]interface{})
if !ok {
return false, fmt.Errorf("secret_scanning field not found")
}

secretScanningStatus, ok := secretScanning["status"].(string)
if !ok || secretScanningStatus != "enabled" {
return false, fmt.Errorf("secret scanning is not enabled for " + repository + Yellow("\nEnable secret scanning for the repository here: https://github.com/"+repository+"/settings/security_analysis"))
} else {
secretScanningEnabled = true
}

return secretScanningEnabled, nil
}
6 changes: 5 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"log"
"strings"

"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -37,13 +38,16 @@ func init() {
// disable completion subcommand:
rootCmd.CompletionOptions.DisableDefaultCmd = true

// silence usage output on error:
rootCmd.SilenceUsage = true

rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) (err error) {
// warn user about --show-secret flag:
if secret {
fmt.Println(Yellow("WARNING: --show-secret flag is enabled. Full secret values will be displayed in PLAIN TEXT in the output. Would you like to continue? (y/n)"))
var response string
fmt.Scanln(&response)
if response != "y" {
if strings.ToLower(response) != "y" && strings.ToLower(response) != "yes" {
log.Fatal("Exiting...")
}
}
Expand Down
25 changes: 10 additions & 15 deletions cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,19 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
// set scope & target based on the flag that was used:
scope, target, err := getScopeAndTarget()
if err != nil {
fmt.Println(err)
return
return err
}

// set the API URL based on the target:
requestPath, err := createGitHubSecretAlertsAPIPath(scope, target)
if err != nil {
fmt.Println(err)
return
return err
}

// update the URL to include query parameters based on specified flags:
parsedURL, err := url.Parse(requestPath)
if err != nil {
fmt.Println(err)
return
return err
}
values := parsedURL.Query()
var per_page string
Expand All @@ -54,7 +51,7 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
}
per_page_int, err := strconv.Atoi(per_page)
if err != nil {
fmt.Println(err)
return err
}
values.Set("per_page", per_page)
// if provider was specified, filter results for just that provider. Otherwise, target all supported providers:
Expand All @@ -75,21 +72,17 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
opts := setOptions()
client, err := api.NewRESTClient(opts)
if err != nil {
fmt.Println(err)
return err
}

for page := 1; page <= pages; page++ {
fmt.Println("Processing page: " + strconv.Itoa(page))
_, nextPage, err := callGitHubAPI(client, requestPath, &pageOfSecretAlerts, GET)
if err != nil {
fmt.Println("ERROR: Unable to get alerts for target: " + requestPath)
return err
}
for _, secretAlert := range pageOfSecretAlerts {
// add each secret alert in the response page to allSecretAlerts array
allSecretAlerts = append(allSecretAlerts, secretAlert)
}
// add each secret alert in the response page to allSecretAlerts array
allSecretAlerts = append(allSecretAlerts, pageOfSecretAlerts...)
var hasNextPage bool
if requestPath, hasNextPage = findNextPage(nextPage); !hasNextPage {
break
Expand All @@ -110,13 +103,14 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
// verify which secret alerts are confirmed valid:
verifiedAlerts, err := verifyAlerts(sortedAlerts)
if err != nil {
// print to console
fmt.Println("WARNING: issues encountered while sending verify requests.")
}

// pretty print with validity status
if !quiet {
prettyPrintAlerts(verifiedAlerts, true)
}

// optionally generate a csv report of the results:
if len(sortedAlerts) > 0 && csvReport {
err = generateCSVReport(sortedAlerts, scope, true)
Expand All @@ -125,6 +119,7 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
return err
}
}

// optionally create an issue for each repository that contains at least one valid secret alert:
if createIssues {
err = createIssuesForValidAlerts(verifiedAlerts)
Expand All @@ -133,5 +128,5 @@ func runVerify(cmd *cobra.Command, args []string) (err error) {
return err
}
}
return err
return
}

0 comments on commit 68f845f

Please sign in to comment.