From 3f491a3c5300621a95d231a7f4892e8e1d6e53f9 Mon Sep 17 00:00:00 2001 From: Joshua Wang Date: Tue, 30 Jul 2024 18:37:53 -0700 Subject: [PATCH] better concurrency support on Linux (#306) * commandrun: keep track of newly created files * product: only attest for opened files when tracing is enabled * file: do not attempt to record an artifact if it was not opened by the process --------- Signed-off-by: Joshua Wang --- attestation/commandrun/commandrun.go | 4 ++++ attestation/commandrun/tracing_linux.go | 26 +++++++++++++++++++++++++ attestation/file/file.go | 14 ++++++++----- attestation/file/file_test.go | 6 +++--- attestation/material/material.go | 2 +- attestation/product/product.go | 20 ++++++++++++++++++- 6 files changed, 62 insertions(+), 10 deletions(-) diff --git a/attestation/commandrun/commandrun.go b/attestation/commandrun/commandrun.go index 5fc2007e..431d9ca7 100644 --- a/attestation/commandrun/commandrun.go +++ b/attestation/commandrun/commandrun.go @@ -160,6 +160,10 @@ func (rc *CommandRun) RunType() attestation.RunType { return RunType } +func (rc *CommandRun) TracingEnabled() bool { + return rc.enableTracing +} + func (r *CommandRun) runCmd(ctx *attestation.AttestationContext) error { c := exec.Command(r.Cmd[0], r.Cmd[1:]...) 
c.Dir = ctx.WorkingDir() diff --git a/attestation/commandrun/tracing_linux.go b/attestation/commandrun/tracing_linux.go index cd473c47..90eb9c24 100644 --- a/attestation/commandrun/tracing_linux.go +++ b/attestation/commandrun/tracing_linux.go @@ -74,6 +74,8 @@ func (r *CommandRun) trace(c *exec.Cmd, actx *attestation.AttestationContext) ([ } func (p *ptraceContext) runTrace() error { + defer p.retryOpenedFiles() + runtime.LockOSThread() defer runtime.UnlockOSThread() status := unix.WaitStatus(0) @@ -121,6 +123,26 @@ func (p *ptraceContext) runTrace() error { } } +func (p *ptraceContext) retryOpenedFiles() { + // after tracing, look through opened files to try to resolve any newly created files + procInfo := p.getProcInfo(p.parentPid) + + for file, digestSet := range procInfo.OpenedFiles { + if digestSet != nil { + continue + } + + newDigest, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash) + + if err != nil { + delete(procInfo.OpenedFiles, file) + continue + } + + procInfo.OpenedFiles[file] = newDigest + } +} + func (p *ptraceContext) nextSyscall(pid int) error { regs := unix.PtraceRegs{} if err := unix.PtraceGetRegs(pid, &regs); err != nil { @@ -213,6 +235,10 @@ func (p *ptraceContext) handleSyscall(pid int, regs unix.PtraceRegs) error { procInfo := p.getProcInfo(pid) digestSet, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash) if err != nil { + if _, isPathErr := err.(*os.PathError); isPathErr { + procInfo.OpenedFiles[file] = nil + } + return err } diff --git a/attestation/file/file.go b/attestation/file/file.go index 36a37e4e..14065d6f 100644 --- a/attestation/file/file.go +++ b/attestation/file/file.go @@ -26,7 +26,7 @@ import ( // recordArtifacts will walk basePath and record the digests of each file with each of the functions in hashes. // If file already exists in baseArtifacts and the two artifacts are equal the artifact will not be in the // returned map of artifacts. 
-func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}) (map[string]cryptoutil.DigestSet, error) { +func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}, processWasTraced bool, openedFiles map[string]bool) (map[string]cryptoutil.DigestSet, error) { artifacts := make(map[string]cryptoutil.DigestSet) err := filepath.Walk(basePath, func(path string, info fs.FileInfo, err error) error { if err != nil { @@ -57,7 +57,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest } visitedSymlinks[linkedPath] = struct{}{} - symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks) + symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks, processWasTraced, openedFiles) if err != nil { return err } @@ -65,7 +65,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest for artifactPath, artifact := range symlinkedArtifacts { // all artifacts in the symlink should be recorded relative to our basepath joinedPath := filepath.Join(relPath, artifactPath) - if shouldRecord(joinedPath, artifact, baseArtifacts) { + if shouldRecord(joinedPath, artifact, baseArtifacts, processWasTraced, openedFiles) { artifacts[filepath.Join(relPath, artifactPath)] = artifact } } @@ -78,7 +78,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest return err } - if shouldRecord(relPath, artifact, baseArtifacts) { + if shouldRecord(relPath, artifact, baseArtifacts, processWasTraced, openedFiles) { artifacts[relPath] = artifact } @@ -89,9 +89,13 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest } // shouldRecord determines whether artifact should be recorded. 
+// if the process was traced and the artifact was not one of the opened files, return false // if the artifact is already in baseArtifacts, check if it's changed // if it is not equal to the existing artifact, return true, otherwise return false -func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet) bool { +func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet, processWasTraced bool, openedFiles map[string]bool) bool { + if _, ok := openedFiles[path]; !ok && processWasTraced { + return false + } if previous, ok := baseArtifacts[path]; ok && artifact.Equal(previous) { return false } diff --git a/attestation/file/file_test.go b/attestation/file/file_test.go index 73344bff..5379a487 100644 --- a/attestation/file/file_test.go +++ b/attestation/file/file_test.go @@ -38,13 +38,13 @@ func TestBrokenSymlink(t *testing.T) { symTestDir := filepath.Join(dir, "symTestDir") require.NoError(t, os.Symlink(testDir, symTestDir)) - _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{}) require.NoError(t, err) // remove the symlinks and make sure we don't get an error back require.NoError(t, os.RemoveAll(testDir)) require.NoError(t, os.RemoveAll(testFile)) - _, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{}) require.NoError(t, err) } @@ -58,6 +58,6 @@ func TestSymlinkCycle(t *testing.T) { require.NoError(t, os.Symlink(dir, symTestDir)) // if a symlink cycle weren't properly handled 
this would be an infinite loop - _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}) + _, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{}) require.NoError(t, err) } diff --git a/attestation/material/material.go b/attestation/material/material.go index 74f047c0..6b99a4e3 100644 --- a/attestation/material/material.go +++ b/attestation/material/material.go @@ -90,7 +90,7 @@ func (a *Attestor) Schema() *jsonschema.Schema { } func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { - materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}) + materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}, false, map[string]bool{}) if err != nil { return err } diff --git a/attestation/product/product.go b/attestation/product/product.go index 9e834f32..8c9d6c34 100644 --- a/attestation/product/product.go +++ b/attestation/product/product.go @@ -23,6 +23,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/gobwas/glob" "github.com/in-toto/go-witness/attestation" + "github.com/in-toto/go-witness/attestation/commandrun" "github.com/in-toto/go-witness/attestation/file" "github.com/in-toto/go-witness/cryptoutil" "github.com/in-toto/go-witness/registry" @@ -181,7 +182,24 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { a.compiledExcludeGlob = compiledExcludeGlob a.baseArtifacts = ctx.Materials() - products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}) + + processWasTraced := false + openedFileSet := map[string]bool{} + + for _, completedAttestor := range ctx.CompletedAttestors() { + attestor := completedAttestor.Attestor + if commandRunAttestor, ok := attestor.(*commandrun.CommandRun); ok && 
commandRunAttestor.TracingEnabled() { + processWasTraced = true + + for _, process := range commandRunAttestor.Processes { + for fname := range process.OpenedFiles { + openedFileSet[fname] = true + } + } + } + } + + products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}, processWasTraced, openedFileSet) if err != nil { return err }