diff --git a/internal/commands/scan.go b/internal/commands/scan.go index 56dca2260..e39a16b13 100644 --- a/internal/commands/scan.go +++ b/internal/commands/scan.go @@ -14,6 +14,7 @@ import ( "path" "path/filepath" "reflect" + "slices" "strconv" "strings" "time" @@ -1648,7 +1649,8 @@ func runCreateScanCommand( return err } - err = applyThreshold(cmd, resultsWrapper, exportWrapper, scanResponseModel, thresholdMap) + err = applyThreshold(cmd, resultsWrapper, exportWrapper, scanResponseModel, thresholdMap, risksOverviewWrapper) + if err != nil { return err } @@ -1901,6 +1903,7 @@ func applyThreshold( exportWrapper wrappers.ExportWrapper, scanResponseModel *wrappers.ScanResponseModel, thresholdMap map[string]int, + risksOverviewWrapper wrappers.RisksOverviewWrapper, ) error { if len(thresholdMap) == 0 { return nil @@ -1912,7 +1915,8 @@ func applyThreshold( params[commonParams.SastRedundancyFlag] = "" } - summaryMap, err := getSummaryThresholdMap(resultsWrapper, exportWrapper, scanResponseModel, params) + summaryMap, err := getSummaryThresholdMap(resultsWrapper, exportWrapper, scanResponseModel, params, risksOverviewWrapper) + if err != nil { return err } @@ -1995,21 +1999,35 @@ func parseThresholdLimit(limit string) (engineName string, intLimit int, err err return engineName, intLimit, err } -func getSummaryThresholdMap(resultsWrapper wrappers.ResultsWrapper, exportWrapper wrappers.ExportWrapper, scan *wrappers.ScanResponseModel, params map[string]string) ( - map[string]int, - error, -) { +func getSummaryThresholdMap( + resultsWrapper wrappers.ResultsWrapper, + exportWrapper wrappers.ExportWrapper, + scan *wrappers.ScanResponseModel, + params map[string]string, + risksOverviewWrapper wrappers.RisksOverviewWrapper, +) (map[string]int, error) { + summaryMap := make(map[string]int) results, err := ReadResults(resultsWrapper, exportWrapper, scan, params) + if err != nil { return nil, err } - summaryMap := make(map[string]int) for _, result := range results.Results { if isExploitable(result.State) { key := strings.ToLower(fmt.Sprintf("%s-%s", strings.Replace(result.Type, commonParams.KicsType, commonParams.IacType, 1), result.Severity)) summaryMap[key]++ } } + + if slices.Contains(scan.Engines, commonParams.APISecType) { + apiSecRisks, err := getResultsForAPISecScanner(risksOverviewWrapper, scan.ID) + if err != nil { + return nil, err + } + summaryMap["api-security-high"] = apiSecRisks.Risks[1] + summaryMap["api-security-medium"] = apiSecRisks.Risks[2] + summaryMap["api-security-low"] = apiSecRisks.Risks[3] + } return summaryMap, nil } diff --git a/test/integration/scan_test.go b/test/integration/scan_test.go index 6d688e67c..bddd0e976 100644 --- a/test/integration/scan_test.go +++ b/test/integration/scan_test.go @@ -660,6 +660,20 @@ func TestScanCreateWithThreshold(t *testing.T) { assert.NilError(t, err, "") } +func TestScansAPISecThresholdShouldBlock(t *testing.T) { + createASTIntegrationTestCommand(t) + testArgs := []string{ + "scan", "create", + flag(params.ProjectName), "my-project", + flag(params.SourcesFlag), "data/sources.zip", + flag(params.BranchFlag), "dummy_branch", + flag(params.ScanInfoFormatFlag), printer.FormatJSON, + flag(params.ScanTypes), "sast, api-security", + flag(params.Threshold), "api-security-high=1", + } + _, _ = executeCommand(t, testArgs...) +} + // Create a scan with the sources // Assert the scan completes func TestScanCreateWithThresholdParseError(t *testing.T) {