Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better concurrency support on Linux #306

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions attestation/commandrun/commandrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func (rc *CommandRun) RunType() attestation.RunType {
return RunType
}

func (rc *CommandRun) TracingEnabled() bool {
return rc.enableTracing
}

func (r *CommandRun) runCmd(ctx *attestation.AttestationContext) error {
c := exec.Command(r.Cmd[0], r.Cmd[1:]...)
c.Dir = ctx.WorkingDir()
Expand Down
26 changes: 26 additions & 0 deletions attestation/commandrun/tracing_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func (r *CommandRun) trace(c *exec.Cmd, actx *attestation.AttestationContext) ([
}

func (p *ptraceContext) runTrace() error {
defer p.retryOpenedFiles()

runtime.LockOSThread()
defer runtime.UnlockOSThread()
status := unix.WaitStatus(0)
Expand Down Expand Up @@ -121,6 +123,26 @@ func (p *ptraceContext) runTrace() error {
}
}

func (p *ptraceContext) retryOpenedFiles() {
// after tracing, look through opened files to try to resolve any newly created files
procInfo := p.getProcInfo(p.parentPid)

for file, digestSet := range procInfo.OpenedFiles {
if digestSet != nil {
continue
}

newDigest, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash)

if err != nil {
delete(procInfo.OpenedFiles, file)
continue
}

procInfo.OpenedFiles[file] = newDigest
}
}

func (p *ptraceContext) nextSyscall(pid int) error {
regs := unix.PtraceRegs{}
if err := unix.PtraceGetRegs(pid, &regs); err != nil {
Expand Down Expand Up @@ -213,6 +235,10 @@ func (p *ptraceContext) handleSyscall(pid int, regs unix.PtraceRegs) error {
procInfo := p.getProcInfo(pid)
digestSet, err := cryptoutil.CalculateDigestSetFromFile(file, p.hash)
if err != nil {
if _, isPathErr := err.(*os.PathError); isPathErr {
procInfo.OpenedFiles[file] = nil
}

return err
}

Expand Down
14 changes: 9 additions & 5 deletions attestation/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
// recordArtifacts will walk basePath and record the digests of each file with each of the functions in hashes.
// If file already exists in baseArtifacts and the two artifacts are equal the artifact will not be in the
// returned map of artifacts.
func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}) (map[string]cryptoutil.DigestSet, error) {
func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.DigestSet, hashes []cryptoutil.DigestValue, visitedSymlinks map[string]struct{}, processWasTraced bool, openedFiles map[string]bool) (map[string]cryptoutil.DigestSet, error) {
artifacts := make(map[string]cryptoutil.DigestSet)
err := filepath.Walk(basePath, func(path string, info fs.FileInfo, err error) error {
if err != nil {
Expand Down Expand Up @@ -57,15 +57,15 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
}

visitedSymlinks[linkedPath] = struct{}{}
symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks)
symlinkedArtifacts, err := RecordArtifacts(linkedPath, baseArtifacts, hashes, visitedSymlinks, processWasTraced, openedFiles)
if err != nil {
return err
}

for artifactPath, artifact := range symlinkedArtifacts {
// all artifacts in the symlink should be recorded relative to our basepath
joinedPath := filepath.Join(relPath, artifactPath)
if shouldRecord(joinedPath, artifact, baseArtifacts) {
if shouldRecord(joinedPath, artifact, baseArtifacts, processWasTraced, openedFiles) {
artifacts[filepath.Join(relPath, artifactPath)] = artifact
}
}
Expand All @@ -78,7 +78,7 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
return err
}

if shouldRecord(relPath, artifact, baseArtifacts) {
if shouldRecord(relPath, artifact, baseArtifacts, processWasTraced, openedFiles) {
artifacts[relPath] = artifact
}

Expand All @@ -89,9 +89,13 @@ func RecordArtifacts(basePath string, baseArtifacts map[string]cryptoutil.Digest
}

// shouldRecord determines whether artifact should be recorded.
// if the process was traced and the artifact was not one of the opened files, return false
// if the artifact is already in baseArtifacts, check if it's changed
// if it is not equal to the existing artifact, return true, otherwise return false
func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet) bool {
func shouldRecord(path string, artifact cryptoutil.DigestSet, baseArtifacts map[string]cryptoutil.DigestSet, processWasTraced bool, openedFiles map[string]bool) bool {
if _, ok := openedFiles[path]; !ok && processWasTraced {
return false
}
if previous, ok := baseArtifacts[path]; ok && artifact.Equal(previous) {
return false
}
Expand Down
6 changes: 3 additions & 3 deletions attestation/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ func TestBrokenSymlink(t *testing.T) {
symTestDir := filepath.Join(dir, "symTestDir")
require.NoError(t, os.Symlink(testDir, symTestDir))

_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)

// remove the symlinks and make sure we don't get an error back
require.NoError(t, os.RemoveAll(testDir))
require.NoError(t, os.RemoveAll(testFile))
_, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err = RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)
}

Expand All @@ -58,6 +58,6 @@ func TestSymlinkCycle(t *testing.T) {
require.NoError(t, os.Symlink(dir, symTestDir))

// if a symlink cycle weren't properly handled this would be an infinite loop
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{})
_, err := RecordArtifacts(dir, map[string]cryptoutil.DigestSet{}, []cryptoutil.DigestValue{{Hash: crypto.SHA256}}, map[string]struct{}{}, false, map[string]bool{})
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion attestation/material/material.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (a *Attestor) Schema() *jsonschema.Schema {
}

func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{})
materials, err := file.RecordArtifacts(ctx.WorkingDir(), nil, ctx.Hashes(), map[string]struct{}{}, false, map[string]bool{})
if err != nil {
return err
}
Expand Down
20 changes: 19 additions & 1 deletion attestation/product/product.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/gabriel-vasile/mimetype"
"github.com/gobwas/glob"
"github.com/in-toto/go-witness/attestation"
"github.com/in-toto/go-witness/attestation/commandrun"
"github.com/in-toto/go-witness/attestation/file"
"github.com/in-toto/go-witness/cryptoutil"
"github.com/in-toto/go-witness/registry"
Expand Down Expand Up @@ -181,7 +182,24 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
a.compiledExcludeGlob = compiledExcludeGlob

a.baseArtifacts = ctx.Materials()
products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{})

processWasTraced := false
openedFileSet := map[string]bool{}

for _, completedAttestor := range ctx.CompletedAttestors() {
attestor := completedAttestor.Attestor
if commandRunAttestor, ok := attestor.(*commandrun.CommandRun); ok && commandRunAttestor.TracingEnabled() {
processWasTraced = true

for _, process := range commandRunAttestor.Processes {
for fname := range process.OpenedFiles {
openedFileSet[fname] = true
}
}
}
}

products, err := file.RecordArtifacts(ctx.WorkingDir(), a.baseArtifacts, ctx.Hashes(), map[string]struct{}{}, processWasTraced, openedFileSet)
if err != nil {
return err
}
Expand Down
Loading