From 2c79cd4879bb2c6f791ca10bb4453263ef39e144 Mon Sep 17 00:00:00 2001 From: Joshua Wang Date: Wed, 17 Jul 2024 16:20:06 -0700 Subject: [PATCH] file: do not attempt to record an artifact if it was not opened by the process --- attestation/file/file.go | 14 +++++++++----- attestation/file/file_test.go | 6 +++--- attestation/material/material.go | 2 +- attestation/product/product.go | 20 +++++++++----------- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/attestation/file/file.go b/attestation/file/file.go index 36a37e4e..14065d6f 100644 --- a/attestation/file/file.go +++ b/attestation/file/file.go @@ -26,7 +26,7 @@ import ( // recordArtifacts will walk basePath and record the digests of each file with each of the functions in hashes. // If file already exists in baseArtifacts and the two artifacts are equal the artifact will not be in the // returned map of artifacts. -func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}) (map[string]cryptoutil.DigestSet, error) { +func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}, processWasTraced bool, openedFiles map[string]bool) (map[string]cryptoutil.DigestSet, error) { artifacts := make(map[string]cryptoutil.DigestSet) err := filepath.Walk(basePath, func(path string, info fs.FileInfo, err error) error { if err != nil { @@ -57,7 +57,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest } visitedSymlinks[linkedPath] = struct{}{} - symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks) + symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks, processWasTraced, openedFiles) if err != nil { return err } @@ -65,7 +65,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest for artifactPath, artifact := range symlinkedArtifacts { // all artifacts in the symlink should be recorded relative to our basepath joinedPath := filepath.Join(relPath, artifactPath) - if shouldRecord(joinedPath, artifact, baseArtifacts) { + if shouldRecord(joinedPath, artifact, baseArtifacts, processWasTraced, openedFiles) { artifacts[filepath.Join(relPath, artifactPath)] = artifact } } @@ -78,7 +78,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest return err } - if shouldRecord(relPath, artifact, baseArtifacts) { + if shouldRecord(relPath, artifact, baseArtifacts, processWasTraced, openedFiles) { artifacts[relPath] = artifact } @@ -89,9 +89,13 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest } // shouldRecord determines whether artifact should be recorded. +// if the process was traced and the artifact was not one of the opened files, return false // if the artifact is already in baseArtifacts, check if it's changed // if it is not equal to the existing artifact, return true, otherwise return false -func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet) bool { +func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet, processWasTraced bool, openedFiles map[string]bool) bool { + if _, ok := openedFiles[path]; !ok && processWasTraced { + return false + } if previous, ok := baseArtifacts[path]; ok && artifact.Equal(previous) { return false } diff --git a/attestation/file/file_test.go b/attestation/file/file_test.go index 73344bff..436e4a4b 100644 --- a/attestation/file/file_test.go +++ b/attestation/file/file_test.go @@ -38,13 +38,13 @@ func TestBrokenSymlink(t *testing.T) { symTestDir := filepath.Join(dir, "symTestDir") require.NoError(t, os.Symlink(testDir, symTestDir)) - _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool) require.NoError(t, err) // remove the symlinks and make sure we don't get an error back require.NoError(t, os.RemoveAll(testDir)) require.NoError(t, os.RemoveAll(testFile)) - _, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool) require.NoError(t, err) } @@ -58,6 +58,6 @@ func TestSymlinkCycle(t *testing.T) { require.NoError(t, os.Symlink(dir, symTestDir)) // if a symlink cycle weren't properly handled this would be an infinite loop - _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool) require.NoError(t, err) } diff --git a/attestation/material/material.go b/attestation/material/material.go index 74f047c0..6b99a4e3 100644 --- a/attestation/material/material.go +++ b/attestation/material/material.go @@ -90,7 +90,7 @@ func (a *Attestor) Schema() *jsonschema.Schema { } func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { - materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}) + materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}, false, map[string]bool{}) if err != nil { return err } diff --git a/attestation/product/product.go b/attestation/product/product.go index 4e370c0f..d70987e7 100644 --- a/attestation/product/product.go +++ b/attestation/product/product.go @@ -182,30 +182,28 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { a.compiledExcludeGlob = compiledExcludeGlob a.baseArtifacts = ctx.Materials() - products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}) - if err != nil { - return err - } + + processWasTraced := false + openedFileSet := map[string]bool{} for _, completedAttestor := range ctx.CompletedAttestors() { attestor := completedAttestor.Attestor if commandRunAttestor, ok := attestor.(*commandrun.CommandRun); ok && commandRunAttestor.EnableTracing { - openedFileSet := map[string]bool{} + processWasTraced = true for _, process := range commandRunAttestor.Processes { for file := range process.OpenedFiles { openedFileSet[file] = true; } } - - for file := range products { - if _, ok := openedFileSet[file]; !ok { - delete(products, file) - } - } } } + products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}, processWasTraced, openedFileSet) + if err != nil { + return err + } + a.products = fromDigestMap(ctx.WorkingDir(), products) return nil }