diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index 90a4d9656..bcaf1e043 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -13,8 +13,6 @@ import ( "github.com/hhatto/gocloc" "github.com/rs/zerolog/log" - "golang.org/x/exp/maps" - "github.com/bearer/bearer/pkg/commands/artifact/scanid" "github.com/bearer/bearer/pkg/commands/process/filelist" "github.com/bearer/bearer/pkg/commands/process/filelist/files" @@ -33,6 +31,7 @@ import ( "github.com/bearer/bearer/pkg/util/ignore" ignoretypes "github.com/bearer/bearer/pkg/util/ignore/types" outputhandler "github.com/bearer/bearer/pkg/util/output" + "github.com/bearer/bearer/pkg/util/set" "github.com/bearer/bearer/pkg/version_check" "github.com/bearer/bearer/pkg/types" @@ -258,10 +257,10 @@ func Run(ctx context.Context, opts flagtypes.Options, engine engine.Engine) (err log.Debug().Msgf("Error in line of code output %s", err) return err } - languageList := FormatFoundLanguages(inputgocloc.Languages) + foundLanguageIDs := GetFoundLanguageIDs(engine, inputgocloc.Languages) // set used language list for external rules to empty if we dont use them - metaLanguageList := languageList + metaLanguageList := foundLanguageIDs if opts.RuleOptions.DisableDefaultRules { metaLanguageList = make([]string, 0) } @@ -290,7 +289,7 @@ func Run(ctx context.Context, opts flagtypes.Options, engine engine.Engine) (err return fmt.Errorf("failed to initialize engine: %w", err) } - scanSettings, err := settingsloader.FromOptions(opts, versionMeta, engine) + scanSettings, err := settingsloader.FromOptions(opts, versionMeta, engine, foundLanguageIDs) scanSettings.Target = opts.Target if err != nil { return err @@ -476,18 +475,18 @@ func getPlaceholderOutput(reportData *outputtypes.ReportData, report types.Repor return stats.GetPlaceholderOutput(reportData, inputgocloc, config) } -func FormatFoundLanguages(languages map[string]*gocloc.Language) (foundLanguages []string) { - var foundLanguagesMap = make(map[string]bool, len(languages)) +func GetFoundLanguageIDs(engine engine.Engine, goclocLanguages map[string]*gocloc.Language) []string { + foundLanguages := set.New[string]() - for _, language := range languages { - if language.Name == "TypeScript" { - foundLanguagesMap["javascript"] = true - } else { - foundLanguagesMap[strings.ToLower(language.Name)] = true + for _, goclocLanguage := range goclocLanguages { + for _, language := range engine.GetLanguages() { + if slices.Contains(language.GoclocLanguages(), goclocLanguage.Name) { + foundLanguages.Add(language.ID()) + } } } - keys := maps.Keys(foundLanguagesMap) + keys := foundLanguages.Items() sort.Strings(keys) return keys diff --git a/pkg/commands/artifact/run_test.go b/pkg/commands/artifact/run_test.go index 4e9f343de..aec699b24 100644 --- a/pkg/commands/artifact/run_test.go +++ b/pkg/commands/artifact/run_test.go @@ -3,11 +3,16 @@ package artifact import ( "testing" + engineimpl "github.com/bearer/bearer/pkg/engine/implementation" + "github.com/bearer/bearer/pkg/languages" + "github.com/hhatto/gocloc" "github.com/stretchr/testify/assert" ) -func TestFormatLanguagesWithJavascriptAndTypescript(t *testing.T) { +func TestGetFoundLanguageIDsWithJavascriptAndTypescript(t *testing.T) { + engine := engineimpl.New(languages.Default()) + dummyGoclocLanguage := gocloc.Language{} dummyGoclocResult := gocloc.Result{ Total: &dummyGoclocLanguage, @@ -29,11 +34,13 @@ func TestFormatLanguagesWithJavascriptAndTypescript(t *testing.T) { assert.Equal( t, []string{"javascript", "ruby"}, - FormatFoundLanguages(dummyGoclocResult.Languages), + GetFoundLanguageIDs(engine, dummyGoclocResult.Languages), ) } -func TestFormatLanguagesWithoutJavascript(t *testing.T) { +func TestGetFoundLanguageIDsWithoutJavascript(t *testing.T) { + engine := engineimpl.New(languages.Default()) + dummyGoclocLanguage := gocloc.Language{} dummyGoclocResult := gocloc.Result{ Total: &dummyGoclocLanguage, @@ -52,11 +59,13 @@ func TestFormatLanguagesWithoutJavascript(t *testing.T) { assert.Equal( t, []string{"javascript", "ruby"}, - FormatFoundLanguages(dummyGoclocResult.Languages), + GetFoundLanguageIDs(engine, dummyGoclocResult.Languages), ) } -func TestFormatLanguagesWithJavascriptFirst(t *testing.T) { +func TestGetFoundLanguageIDsWithJavascriptFirst(t *testing.T) { + engine := engineimpl.New(languages.Default()) + dummyGoclocLanguage := gocloc.Language{} dummyGoclocResult := gocloc.Result{ Total: &dummyGoclocLanguage, @@ -78,6 +87,6 @@ func TestFormatLanguagesWithJavascriptFirst(t *testing.T) { assert.Equal( t, []string{"javascript", "ruby"}, - FormatFoundLanguages(dummyGoclocResult.Languages), + GetFoundLanguageIDs(engine, dummyGoclocResult.Languages), ) } diff --git a/pkg/commands/process/settings/loader/loader.go b/pkg/commands/process/settings/loader/loader.go index f7a27ba0e..6d23a316f 100644 --- a/pkg/commands/process/settings/loader/loader.go +++ b/pkg/commands/process/settings/loader/loader.go @@ -19,6 +19,7 @@ func FromOptions( opts flagtypes.Options, versionMeta *version_check.VersionMeta, engine engine.Engine, + foundLanguageIDs []string, ) (settings.Config, error) { policies, err := policies.Load() if err != nil { @@ -31,6 +32,7 @@ func FromOptions( versionMeta, engine, opts.ScanOptions.Force, + foundLanguageIDs, ) if err != nil { return settings.Config{}, err diff --git a/pkg/commands/process/settings/rules/loader.go b/pkg/commands/process/settings/rules/loader.go index 8aeb31192..26b1f1a15 100644 --- a/pkg/commands/process/settings/rules/loader.go +++ b/pkg/commands/process/settings/rules/loader.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path/filepath" + "slices" "strings" "time" @@ -159,7 +160,13 @@ func readRuleDefinitionZip(ruleDefinitions map[string]settings.RuleDefinition, f return nil } -func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings.RuleDefinition, isBuiltIn bool, dir fs.FS) error { +func loadCustomDefinitions( + engine engine.Engine, + definitions map[string]settings.RuleDefinition, + isBuiltIn bool, + dir fs.FS, + languageIDs []string, +) error { loadedDefinitions := make(map[string]settings.RuleDefinition) if err := fs.WalkDir(dir, ".", func(path string, dirEntry fs.DirEntry, err error) error { if err != nil { @@ -186,7 +193,7 @@ func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings err = yaml.Unmarshal(entry, &ruleDefinition) if err != nil { output.StdErrLog(validateCustomRuleSchema(entry, filename)) - return fmt.Errorf("rule file was invalid") + return fmt.Errorf("rule file was invalid: %w", err) } if ruleDefinition.Metadata == nil { @@ -202,8 +209,7 @@ func loadCustomDefinitions(engine engine.Engine, definitions map[string]settings supported := isBuiltIn for _, languageID := range ruleDefinition.Languages { - language := engine.GetLanguageById(languageID) - if language != nil { + if slices.Contains(languageIDs, languageID) { supported = true } } diff --git a/pkg/commands/process/settings/rules/rules.go b/pkg/commands/process/settings/rules/rules.go index 48b5e6b41..914ca4343 100644 --- a/pkg/commands/process/settings/rules/rules.go +++ b/pkg/commands/process/settings/rules/rules.go @@ -39,6 +39,7 @@ func Load( versionMeta *version_check.VersionMeta, engine engine.Engine, force bool, + foundLanguageIDs []string, ) ( result LoadRulesResult, err error, @@ -56,7 +57,7 @@ func Load( return result, fmt.Errorf("error loading remote rules: %w", err) } - if err := loadCustomDefinitions(engine, builtInDefinitions, true, builtInRulesFS); err != nil { + if err := loadCustomDefinitions(engine, builtInDefinitions, true, builtInRulesFS, nil); err != nil { return result, fmt.Errorf("error loading built-in rules: %w", err) } @@ -66,7 +67,7 @@ func Load( dir = filepath.Join(dirname, dir[2:]) } log.Debug().Msgf("loading external rules from: %s", dir) - if err := loadCustomDefinitions(engine, definitions, false, os.DirFS(dir)); err != nil { + if err := loadCustomDefinitions(engine, definitions, false, os.DirFS(dir), foundLanguageIDs); err != nil { return result, fmt.Errorf("external rules %w", err) } } diff --git a/pkg/languages/testhelper/testhelper.go b/pkg/languages/testhelper/testhelper.go index ee67e6eda..5c3a30472 100644 --- a/pkg/languages/testhelper/testhelper.go +++ b/pkg/languages/testhelper/testhelper.go @@ -32,6 +32,7 @@ import ( "github.com/bearer/bearer/pkg/scanner/language" "github.com/bearer/bearer/pkg/types" util "github.com/bearer/bearer/pkg/util/output" + "github.com/bearer/bearer/pkg/util/set" "github.com/bearer/bearer/pkg/version_check" ) @@ -197,12 +198,18 @@ func buildConfig(t *testing.T, engine engine.Engine, ruleBytes []byte) settings. }, } - config, err := settingsloader.FromOptions(configFlags, meta, engine) + rules := getRulesFromYaml(t, ruleBytes) + languageIDs := set.New[string]() + for _, rule := range rules { + languageIDs.AddAll(rule.Languages) + } + + config, err := settingsloader.FromOptions(configFlags, meta, engine, languageIDs.Items()) if err != nil { t.Fatalf("failed to generate default scan settings: %s", err) } - config.Rules = getRulesFromYaml(t, ruleBytes) + config.Rules = rules return config } diff --git a/pkg/report/output/privacy/privacy_test.go b/pkg/report/output/privacy/privacy_test.go index 88f515512..ab3245623 100644 --- a/pkg/report/output/privacy/privacy_test.go +++ b/pkg/report/output/privacy/privacy_test.go @@ -11,6 +11,7 @@ import ( engineimpl "github.com/bearer/bearer/pkg/engine/implementation" flagtypes "github.com/bearer/bearer/pkg/flag/types" "github.com/bearer/bearer/pkg/languages" + "github.com/bearer/bearer/pkg/languages/ruby" "github.com/bearer/bearer/pkg/report/output/dataflow/types" "github.com/bearer/bearer/pkg/report/output/privacy" "github.com/bearer/bearer/pkg/report/output/testhelper" @@ -82,7 +83,7 @@ func generateConfig(engine engine.Engine, reportOptions flagtypes.ReportOptions) }, } - return settingsloader.FromOptions(opts, meta, engine) + return settingsloader.FromOptions(opts, meta, engine, []string{ruby.Get().ID()}) } func dummyDataflow() *outputtypes.DataFlow { diff --git a/pkg/report/output/security/security_test.go b/pkg/report/output/security/security_test.go index 1125e6fb7..967fd0d4e 100644 --- a/pkg/report/output/security/security_test.go +++ b/pkg/report/output/security/security_test.go @@ -15,6 +15,7 @@ import ( flagtypes "github.com/bearer/bearer/pkg/flag/types" "github.com/bearer/bearer/pkg/git" "github.com/bearer/bearer/pkg/languages" + "github.com/bearer/bearer/pkg/languages/ruby" "github.com/bearer/bearer/pkg/report/basebranchfindings" "github.com/bearer/bearer/pkg/report/schema" globaltypes "github.com/bearer/bearer/pkg/types" @@ -334,7 +335,7 @@ func generateConfig(engine engine.Engine, reportOptions flagtypes.ReportOptions) }, } - return settingsloader.FromOptions(opts, meta, engine) + return settingsloader.FromOptions(opts, meta, engine, []string{ruby.Get().ID()}) } func dummyDataflowData() *outputtypes.ReportData {