Skip to content

Commit

Permalink
fix: only load applicable rules (#1650)
Browse files Browse the repository at this point in the history
  • Loading branch information
didroe authored Jul 9, 2024
1 parent 2eb2f94 commit fa94098
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 29 deletions.
25 changes: 12 additions & 13 deletions pkg/commands/artifact/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
"github.com/hhatto/gocloc"
"github.com/rs/zerolog/log"

"golang.org/x/exp/maps"

"github.com/bearer/bearer/pkg/commands/artifact/scanid"
"github.com/bearer/bearer/pkg/commands/process/filelist"
"github.com/bearer/bearer/pkg/commands/process/filelist/files"
Expand All @@ -33,6 +31,7 @@ import (
"github.com/bearer/bearer/pkg/util/ignore"
ignoretypes "github.com/bearer/bearer/pkg/util/ignore/types"
outputhandler "github.com/bearer/bearer/pkg/util/output"
"github.com/bearer/bearer/pkg/util/set"
"github.com/bearer/bearer/pkg/version_check"

"github.com/bearer/bearer/pkg/types"
Expand Down Expand Up @@ -258,10 +257,10 @@ func Run(ctx context.Context, opts flagtypes.Options, engine engine.Engine) (err
log.Debug().Msgf("Error in line of code output %s", err)
return err
}
languageList := FormatFoundLanguages(inputgocloc.Languages)
foundLanguageIDs := GetFoundLanguageIDs(engine, inputgocloc.Languages)

// set used language list for external rules to empty if we dont use them
metaLanguageList := languageList
metaLanguageList := foundLanguageIDs
if opts.RuleOptions.DisableDefaultRules {
metaLanguageList = make([]string, 0)
}
Expand Down Expand Up @@ -290,7 +289,7 @@ func Run(ctx context.Context, opts flagtypes.Options, engine engine.Engine) (err
return fmt.Errorf("failed to initialize engine: %w", err)
}

scanSettings, err := settingsloader.FromOptions(opts, versionMeta, engine)
scanSettings, err := settingsloader.FromOptions(opts, versionMeta, engine, foundLanguageIDs)
scanSettings.Target = opts.Target
if err != nil {
return err
Expand Down Expand Up @@ -476,18 +475,18 @@ func getPlaceholderOutput(reportData *outputtypes.ReportData, report types.Repor
return stats.GetPlaceholderOutput(reportData, inputgocloc, config)
}

func FormatFoundLanguages(languages map[string]*gocloc.Language) (foundLanguages []string) {
var foundLanguagesMap = make(map[string]bool, len(languages))
func GetFoundLanguageIDs(engine engine.Engine, goclocLanguages map[string]*gocloc.Language) []string {
foundLanguages := set.New[string]()

for _, language := range languages {
if language.Name == "TypeScript" {
foundLanguagesMap["javascript"] = true
} else {
foundLanguagesMap[strings.ToLower(language.Name)] = true
for _, goclocLanguage := range goclocLanguages {
for _, language := range engine.GetLanguages() {
if slices.Contains(language.GoclocLanguages(), goclocLanguage.Name) {
foundLanguages.Add(language.ID())
}
}
}

keys := maps.Keys(foundLanguagesMap)
keys := foundLanguages.Items()
sort.Strings(keys)

return keys
Expand Down
21 changes: 15 additions & 6 deletions pkg/commands/artifact/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package artifact
import (
"testing"

engineimpl "github.com/bearer/bearer/pkg/engine/implementation"
"github.com/bearer/bearer/pkg/languages"

"github.com/hhatto/gocloc"
"github.com/stretchr/testify/assert"
)

func TestFormatLanguagesWithJavascriptAndTypescript(t *testing.T) {
func TestGetFoundLanguageIDsWithJavascriptAndTypescript(t *testing.T) {
engine := engineimpl.New(languages.Default())

dummyGoclocLanguage := gocloc.Language{}
dummyGoclocResult := gocloc.Result{
Total: &dummyGoclocLanguage,
Expand All @@ -29,11 +34,13 @@ func TestFormatLanguagesWithJavascriptAndTypescript(t *testing.T) {
assert.Equal(
t,
[]string{"javascript", "ruby"},
FormatFoundLanguages(dummyGoclocResult.Languages),
GetFoundLanguageIDs(engine, dummyGoclocResult.Languages),
)
}

func TestFormatLanguagesWithoutJavascript(t *testing.T) {
func TestGetFoundLanguageIDsWithoutJavascript(t *testing.T) {
engine := engineimpl.New(languages.Default())

dummyGoclocLanguage := gocloc.Language{}
dummyGoclocResult := gocloc.Result{
Total: &dummyGoclocLanguage,
Expand All @@ -52,11 +59,13 @@ func TestFormatLanguagesWithoutJavascript(t *testing.T) {
assert.Equal(
t,
[]string{"javascript", "ruby"},
FormatFoundLanguages(dummyGoclocResult.Languages),
GetFoundLanguageIDs(engine, dummyGoclocResult.Languages),
)
}

func TestFormatLanguagesWithJavascriptFirst(t *testing.T) {
func TestGetFoundLanguageIDsWithJavascriptFirst(t *testing.T) {
engine := engineimpl.New(languages.Default())

dummyGoclocLanguage := gocloc.Language{}
dummyGoclocResult := gocloc.Result{
Total: &dummyGoclocLanguage,
Expand All @@ -78,6 +87,6 @@ func TestFormatLanguagesWithJavascriptFirst(t *testing.T) {
assert.Equal(
t,
[]string{"javascript", "ruby"},
FormatFoundLanguages(dummyGoclocResult.Languages),
GetFoundLanguageIDs(engine, dummyGoclocResult.Languages),
)
}
2 changes: 2 additions & 0 deletions pkg/commands/process/settings/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func FromOptions(
opts flagtypes.Options,
versionMeta *version_check.VersionMeta,
engine engine.Engine,
foundLanguageIDs []string,
) (settings.Config, error) {
policies, err := policies.Load()
if err != nil {
Expand All @@ -31,6 +32,7 @@ func FromOptions(
versionMeta,
engine,
opts.ScanOptions.Force,
foundLanguageIDs,
)
if err != nil {
return settings.Config{}, err
Expand Down
14 changes: 10 additions & 4 deletions pkg/commands/process/settings/rules/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -159,7 +160,13 @@ func readRuleDefinitionZip(ruleDefinitions map[string]settings.RuleDefinition, f
return nil
}

func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings.RuleDefinition, isBuiltIn bool, dir fs.FS) error {
func loadCustomDefinitions(
engine engine.Engine,
definitions map[string]settings.RuleDefinition,
isBuiltIn bool,
dir fs.FS,
languageIDs []string,
) error {
loadedDefinitions := make(map[string]settings.RuleDefinition)
if err := fs.WalkDir(dir, ".", func(path string, dirEntry fs.DirEntry, err error) error {
if err != nil {
Expand All @@ -186,7 +193,7 @@ func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings
err = yaml.Unmarshal(entry, &ruleDefinition)
if err != nil {
output.StdErrLog(validateCustomRuleSchema(entry, filename))
return fmt.Errorf("rule file was invalid")
return fmt.Errorf("rule file was invalid: %w", err)
}

if ruleDefinition.Metadata == nil {
Expand All @@ -202,8 +209,7 @@ func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings

supported := isBuiltIn
for _, languageID := range ruleDefinition.Languages {
language := engine.GetLanguageById(languageID)
if language != nil {
if slices.Contains(languageIDs, languageID) {
supported = true
}
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/commands/process/settings/rules/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func Load(
versionMeta *version_check.VersionMeta,
engine engine.Engine,
force bool,
foundLanguageIDs []string,
) (
result LoadRulesResult,
err error,
Expand All @@ -56,7 +57,7 @@ func Load(
return result, fmt.Errorf("error loading remote rules: %w", err)
}

if err := loadCustomDefinitions(engine, builtInDefinitions, true, builtInRulesFS); err != nil {
if err := loadCustomDefinitions(engine, builtInDefinitions, true, builtInRulesFS, nil); err != nil {
return result, fmt.Errorf("error loading built-in rules: %w", err)
}

Expand All @@ -66,7 +67,7 @@ func Load(
dir = filepath.Join(dirname, dir[2:])
}
log.Debug().Msgf("loading external rules from: %s", dir)
if err := loadCustomDefinitions(engine, definitions, false, os.DirFS(dir)); err != nil {
if err := loadCustomDefinitions(engine, definitions, false, os.DirFS(dir), foundLanguageIDs); err != nil {
return result, fmt.Errorf("external rules %w", err)
}
}
Expand Down
11 changes: 9 additions & 2 deletions pkg/languages/testhelper/testhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/bearer/bearer/pkg/scanner/language"
"github.com/bearer/bearer/pkg/types"
util "github.com/bearer/bearer/pkg/util/output"
"github.com/bearer/bearer/pkg/util/set"
"github.com/bearer/bearer/pkg/version_check"
)

Expand Down Expand Up @@ -197,12 +198,18 @@ func buildConfig(t *testing.T, engine engine.Engine, ruleBytes []byte) settings.
},
}

config, err := settingsloader.FromOptions(configFlags, meta, engine)
rules := getRulesFromYaml(t, ruleBytes)
languageIDs := set.New[string]()
for _, rule := range rules {
languageIDs.AddAll(rule.Languages)
}

config, err := settingsloader.FromOptions(configFlags, meta, engine, languageIDs.Items())
if err != nil {
t.Fatalf("failed to generate default scan settings: %s", err)
}

config.Rules = getRulesFromYaml(t, ruleBytes)
config.Rules = rules

return config
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/report/output/privacy/privacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
engineimpl "github.com/bearer/bearer/pkg/engine/implementation"
flagtypes "github.com/bearer/bearer/pkg/flag/types"
"github.com/bearer/bearer/pkg/languages"
"github.com/bearer/bearer/pkg/languages/ruby"
"github.com/bearer/bearer/pkg/report/output/dataflow/types"
"github.com/bearer/bearer/pkg/report/output/privacy"
"github.com/bearer/bearer/pkg/report/output/testhelper"
Expand Down Expand Up @@ -82,7 +83,7 @@ func generateConfig(engine engine.Engine, reportOptions flagtypes.ReportOptions)
},
}

return settingsloader.FromOptions(opts, meta, engine)
return settingsloader.FromOptions(opts, meta, engine, []string{ruby.Get().ID()})
}

func dummyDataflow() *outputtypes.DataFlow {
Expand Down
3 changes: 2 additions & 1 deletion pkg/report/output/security/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
flagtypes "github.com/bearer/bearer/pkg/flag/types"
"github.com/bearer/bearer/pkg/git"
"github.com/bearer/bearer/pkg/languages"
"github.com/bearer/bearer/pkg/languages/ruby"
"github.com/bearer/bearer/pkg/report/basebranchfindings"
"github.com/bearer/bearer/pkg/report/schema"
globaltypes "github.com/bearer/bearer/pkg/types"
Expand Down Expand Up @@ -334,7 +335,7 @@ func generateConfig(engine engine.Engine, reportOptions flagtypes.ReportOptions)
},
}

return settingsloader.FromOptions(opts, meta, engine)
return settingsloader.FromOptions(opts, meta, engine, []string{ruby.Get().ID()})
}

func dummyDataflowData() *outputtypes.ReportData {
Expand Down

0 comments on commit fa94098

Please sign in to comment.