diff --git a/cli/cmd/generate/catalog.go b/cli/cmd/generate/catalog.go index 943ae810..1d77cda6 100644 --- a/cli/cmd/generate/catalog.go +++ b/cli/cmd/generate/catalog.go @@ -5,11 +5,11 @@ import ( "encoding/xml" "fmt" "io/ioutil" - "os" "github.com/sirupsen/logrus" "github.com/docker/oscalkit/generator" + "github.com/docker/oscalkit/pkg/oscal_source" "github.com/urfave/cli" ) @@ -45,23 +45,13 @@ var Catalog = cli.Command{ return nil }, Action: func(c *cli.Context) error { - - profilePath, err := generator.GetAbsolutePath(profilePath) - if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot get absolute path, err: %v", err), 1) - } - - _, err = os.Stat(profilePath) - if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot fetch file, err %v", err), 1) - } - f, err := os.Open(profilePath) + os, err := oscal_source.Open(profilePath) if err != nil { return cli.NewExitError(err, 1) } - defer f.Close() + defer os.Close() - profile, err := generator.ReadProfile(f) + profile, err := generator.ReadProfile(os.OSCAL()) if err != nil { return cli.NewExitError(err, 1) } diff --git a/cli/cmd/generate/code.go b/cli/cmd/generate/code.go index a9e0cb09..1bdb6826 100644 --- a/cli/cmd/generate/code.go +++ b/cli/cmd/generate/code.go @@ -8,6 +8,7 @@ import ( "regexp" "github.com/docker/oscalkit/generator" + "github.com/docker/oscalkit/pkg/oscal_source" "github.com/docker/oscalkit/templates" "github.com/docker/oscalkit/types/oscal/catalog" "github.com/sirupsen/logrus" @@ -53,22 +54,13 @@ var Code = cli.Command{ return cli.NewExitError(err, 1) } - profilePath, err := generator.GetAbsolutePath(profilePath) - if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot get absolute path, err: %v", err), 1) - } - - _, err = os.Stat(profilePath) - if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot fetch file, err %v", err), 1) - } - f, err := os.Open(profilePath) + osource, err := oscal_source.Open(profilePath) if err != nil { return cli.NewExitError(err, 1) } - defer f.Close() + defer osource.Close() - profile, err := generator.ReadProfile(f) + profile, err := generator.ReadProfile(osource.OSCAL()) if err != nil { return cli.NewExitError(err, 1) } diff --git a/cli/cmd/info.go b/cli/cmd/info.go index 88909f96..a008bafa 100644 --- a/cli/cmd/info.go +++ b/cli/cmd/info.go @@ -2,10 +2,8 @@ package cmd import ( "fmt" - "os" - "github.com/docker/oscalkit/generator" - "github.com/docker/oscalkit/types/oscal" + "github.com/docker/oscalkit/pkg/oscal_source" "github.com/docker/oscalkit/types/oscal/catalog" "github.com/urfave/cli" ) @@ -17,25 +15,13 @@ var Info = cli.Command{ ArgsUsage: "[file]", Action: func(c *cli.Context) error { for _, filePath := range c.Args() { - profilePath, err := generator.GetAbsolutePath(filePath) + os, err := oscal_source.Open(filePath) if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot get absolute path, err: %v", err), 1) + return cli.NewExitError(fmt.Sprintf("Could not open oscal file: %v", err), 1) } + defer os.Close() - _, err = os.Stat(profilePath) - if err != nil { - return cli.NewExitError(fmt.Sprintf("cannot fetch file, err %v", err), 1) - } - f, err := os.Open(profilePath) - if err != nil { - return cli.NewExitError(err, 1) - } - defer f.Close() - - o, err := oscal.New(f) - if err != nil { - return cli.NewExitError(err, 1) - } + o := os.OSCAL() if o.Profile != nil { fmt.Println("OSCAL Profile (represents subset of controls from OSCAL catalog(s))") fmt.Println("ID:\t", o.Profile.Id) diff --git a/generator/reader.go b/generator/reader.go index c539b816..5f7a5fb9 100644 --- a/generator/reader.go +++ b/generator/reader.go @@ -37,13 +37,8 @@ func ReadCatalog(r io.Reader) (*catalog.Catalog, error) { } -// ReadProfile reads profile from byte array -func ReadProfile(r io.Reader) (*profile.Profile, error) { - - o, err := oscal.New(r) - if err != nil { - return nil, fmt.Errorf("cannot read oscal profile from file. err: %v,", err) - } +// ReadProfile reads profile from OSCAL +func ReadProfile(o *oscal.OSCAL) (*profile.Profile, error) { if o.Profile == nil { return nil, fmt.Errorf("unable to marshall profile") } diff --git a/pkg/oscal_source/oscal_source.go b/pkg/oscal_source/oscal_source.go new file mode 100644 index 00000000..e7a45ae8 --- /dev/null +++ b/pkg/oscal_source/oscal_source.go @@ -0,0 +1,49 @@ +package oscal_source + +import ( + "fmt" + "github.com/docker/oscalkit/types/oscal" + "os" + "path/filepath" +) + +// OSCALSource is intermediary that handles IO and low-level common operations consistently for oscalkit +type OSCALSource struct { + UserPath string + file *os.File + oscal *oscal.OSCAL +} + +// Open creates new OSCALSource and load it up +func Open(path string) (*OSCALSource, error) { + result := OSCALSource{UserPath: path} + return &result, result.open() +} + +func (s *OSCALSource) open() error { + var err error + path := s.UserPath + if !filepath.IsAbs(path) { + if path, err = filepath.Abs(path); err != nil { + return fmt.Errorf("Cannot get absolute path: %v", err) + } + } + if _, err = os.Stat(path); err != nil { + return fmt.Errorf("Cannot stat %s, %v", path, err) + } + if s.file, err = os.Open(path); err != nil { + return fmt.Errorf("Cannot open file %s: %v", path, err) + } + if s.oscal, err = oscal.New(s.file); err != nil { + return fmt.Errorf("Cannot parse file: %v", err) + } + return nil +} + +func (s *OSCALSource) OSCAL() *oscal.OSCAL { + return s.oscal +} + +// Close the OSCALSource +func (s *OSCALSource) Close() { +}