diff --git a/attestation/context.go b/attestation/context.go index bef8e7bc..07ba75d1 100644 --- a/attestation/context.go +++ b/attestation/context.go @@ -21,6 +21,7 @@ import ( "os" "time" + "github.com/gobwas/glob" "github.com/in-toto/go-witness/cryptoutil" "github.com/in-toto/go-witness/log" ) @@ -77,6 +78,20 @@ func WithWorkingDir(workingDir string) AttestationContextOption { } } +func WithDirHashGlob(dirHashGlob []string) AttestationContextOption { + return func(ctx *AttestationContext) { + if len(dirHashGlob) > 0 { + ctx.dirHashGlob = dirHashGlob + + ctx.dirHashGlobCompiled = make([]glob.Glob, len(ctx.dirHashGlob)) + for i, dirHashGlobItem := range dirHashGlob { + dirHashGlobItemCompiled, _ := glob.Compile(dirHashGlobItem) + ctx.dirHashGlobCompiled[i] = dirHashGlobItemCompiled + } + } + } +} + type CompletedAttestor struct { Attestor Attestor StartTime time.Time @@ -85,13 +100,15 @@ type CompletedAttestor struct { } type AttestationContext struct { - ctx context.Context - attestors []Attestor - workingDir string - hashes []cryptoutil.DigestValue - completedAttestors []CompletedAttestor - products map[string]Product - materials map[string]cryptoutil.DigestSet + ctx context.Context + attestors []Attestor + workingDir string + dirHashGlob []string + dirHashGlobCompiled []glob.Glob + hashes []cryptoutil.DigestValue + completedAttestors []CompletedAttestor + products map[string]Product + materials map[string]cryptoutil.DigestSet } type Product struct { @@ -173,6 +190,10 @@ func (ctx *AttestationContext) runAttestor(attestor Attestor) { } } +func (ctx *AttestationContext) DirHashGlob() []glob.Glob { + return ctx.dirHashGlobCompiled +} + func (ctx *AttestationContext) CompletedAttestors() []CompletedAttestor { out := make([]CompletedAttestor, len(ctx.completedAttestors)) copy(out, ctx.completedAttestors) diff --git a/attestation/file/file.go b/attestation/file/file.go index 36a37e4e..2646dad9 100644 --- a/attestation/file/file.go +++ b/attestation/file/file.go @@ -19,6 +19,7 @@ import ( "os" "path/filepath" + "github.com/gobwas/glob" "github.com/in-toto/go-witness/cryptoutil" "github.com/in-toto/go-witness/log" ) @@ -26,22 +27,40 @@ import ( // recordArtifacts will walk basePath and record the digests of each file with each of the functions in hashes. // If file already exists in baseArtifacts and the two artifacts are equal the artifact will not be in the // returned map of artifacts. -func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}) (map[string]cryptoutil.DigestSet, error) { +func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}, dirHashGlob []glob.Glob) (map[string]cryptoutil.DigestSet, error) { artifacts := make(map[string]cryptoutil.DigestSet) err := filepath.Walk(basePath, func(path string, info fs.FileInfo, err error) error { if err != nil { return err } - if info.IsDir() { - return nil - } - relPath, err := filepath.Rel(basePath, path) if err != nil { return err } + if info.IsDir() { + dirHashMatch := false + for _, globItem := range dirHashGlob { + if !dirHashMatch && globItem.Match(relPath) { + dirHashMatch = true + } + } + + if dirHashMatch { + dir, _ := cryptoutil.CalculateDigestSetFromDir(path, hashes) + + if err != nil { + return err + } + + artifacts[relPath + string(os.PathSeparator)] = dir + return filepath.SkipDir + } + + return nil + } + if info.Mode()&fs.ModeSymlink != 0 { // if this is a symlink, eval the true path and eval any artifacts in the symlink. we record every symlink we've visited to prevent infinite loops linkedPath, err := filepath.EvalSymlinks(path) @@ -57,7 +76,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest } visitedSymlinks[linkedPath] = struct{}{} - symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks) + symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks, dirHashGlob) if err != nil { return err } diff --git a/attestation/material/material.go b/attestation/material/material.go index 458515a1..72886aa5 100644 --- a/attestation/material/material.go +++ b/attestation/material/material.go @@ -69,7 +69,7 @@ func New(opts ...Option) *Attestor { } func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { - materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}) + materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}, ctx.DirHashGlob()) if err != nil { return err } diff --git a/attestation/product/product.go b/attestation/product/product.go index 1754d841..1373455a 100644 --- a/attestation/product/product.go +++ b/attestation/product/product.go @@ -107,11 +107,23 @@ func fromDigestMap(digestMap map[string]cryptoutil.DigestSet) map[string]attesta products := make(map[string]attestation.Product) for fileName, digestSet := range digestMap { mimeType := "unknown" + f, err := os.OpenFile(fileName, os.O_RDONLY, 0666) if err == nil { - mimeType, err = getFileContentType(f) + // This returns an *os.FileInfo type + fileInfo, err := f.Stat() if err != nil { - mimeType = "unknown" + // error handling + } + + // IsDir is short for fileInfo.Mode().IsDir() + if fileInfo.IsDir() { + mimeType = "text/directory" + } else { + mimeType, err = getFileContentType(f) + if err != nil { + mimeType = "unknown" + } } f.Close() } @@ -164,7 +176,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { a.compiledExcludeGlob = compiledExcludeGlob a.baseArtifacts = ctx.Materials() - products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}) + products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}, ctx.DirHashGlob()) if err != nil { return err } @@ -202,7 +214,11 @@ func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet { continue } - subjects[fmt.Sprintf("file:%v", productName)] = product.Digest + subjectType := "file" + if product.MimeType == "text/directory" { + subjectType = "dir" + } + subjects[fmt.Sprintf("%v:%v", subjectType, productName)] = product.Digest } return subjects diff --git a/cryptoutil/digestset.go b/cryptoutil/digestset.go index c75d57c0..a7b7a7a4 100644 --- a/cryptoutil/digestset.go +++ b/cryptoutil/digestset.go @@ -22,6 +22,8 @@ import ( "hash" "io" "os" + + "golang.org/x/mod/sumdb/dirhash" ) var ( @@ -203,6 +205,19 @@ func CalculateDigestSetFromFile(path string, hashes []DigestValue) (DigestSet, e return CalculateDigestSet(file, hashes) } +func CalculateDigestSetFromDir(dir string, hashes []DigestValue) (DigestSet, error) { + + dirHash, err := dirhash.HashDir(dir, "", DirhHashSha256) + if err != nil { + return nil, err + } + + digestSetByName := make(map[string]string) + digestSetByName["sha256"] = dirHash + + return NewDigestSet(digestSetByName) +} + func (ds DigestSet) MarshalJSON() ([]byte, error) { nameMap, err := ds.ToNameMap() if err != nil { diff --git a/cryptoutil/dirhash.go b/cryptoutil/dirhash.go new file mode 100644 index 00000000..f35e8c62 --- /dev/null +++ b/cryptoutil/dirhash.go @@ -0,0 +1,47 @@ +package cryptoutil + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "sort" + "strings" +) + +// DirHashSha256 is the "h1:" directory hash function, using SHA-256. +// +// DirHashSha256 returns a SHA-256 hash of a summary +// prepared as if by the Unix command: +// +// sha256sum $(find . -type f | sort) | sha256sum +// +// More precisely, the hashed summary contains a single line for each file in the list, +// ordered by sort.Strings applied to the file names, where each line consists of +// the hexadecimal SHA-256 hash of the file content, +// two spaces (U+0020), the file name, and a newline (U+000A). +// +// File names with newlines (U+000A) are disallowed. +func DirhHashSha256(files []string, open func(string) (io.ReadCloser, error)) (string, error) { + h := sha256.New() + files = append([]string(nil), files...) + sort.Strings(files) + for _, file := range files { + if strings.Contains(file, "\n") { + return "", errors.New("dirhash: filenames with newlines are not supported") + } + r, err := open(file) + if err != nil { + return "", err + } + hf := sha256.New() + _, err = io.Copy(hf, r) + r.Close() + if err != nil { + return "", err + } + fmt.Fprintf(h, "%x %s\n", hf.Sum(nil), file) + } + return hex.EncodeToString(h.Sum(nil)), nil +} \ No newline at end of file