Skip to content

Commit

Permalink
Allow for specifying file permissions when untarring bundle (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany authored Apr 10, 2024
1 parent 3d14b08 commit a88c068
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 14 deletions.
82 changes: 68 additions & 14 deletions fsutil/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ package fsutil
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"

"github.com/kolide/kit/env"
"github.com/pkg/errors"
)

const (
Expand Down Expand Up @@ -93,13 +94,13 @@ func CopyFile(src, dest string) error {
func UntarBundle(destination string, source string) error {
f, err := os.Open(source)
if err != nil {
return errors.Wrap(err, "open download source")
return fmt.Errorf("opening source: %w", err)
}
defer f.Close()

gzr, err := gzip.NewReader(f)
if err != nil {
return errors.Wrapf(err, "create gzip reader from %s", source)
return fmt.Errorf("creating gzip reader from %s: %w", source, err)
}
defer gzr.Close()

Expand All @@ -110,40 +111,93 @@ func UntarBundle(destination string, source string) error {
break
}
if err != nil {
return errors.Wrap(err, "reading tar file")
return fmt.Errorf("reading tar file: %w", err)
}

if err := sanitizeExtractPath(filepath.Dir(destination), header.Name); err != nil {
return errors.Wrap(err, "checking filename")
return fmt.Errorf("checking filename: %w", err)
}

path := filepath.Join(filepath.Dir(destination), header.Name)
destPath := filepath.Join(filepath.Dir(destination), header.Name)
info := header.FileInfo()
if info.IsDir() {
if err = os.MkdirAll(path, info.Mode()); err != nil {
return errors.Wrapf(err, "creating directory for tar file: %s", path)
if err = os.MkdirAll(destPath, info.Mode()); err != nil {
return fmt.Errorf("creating directory %s for tar file: %w", destPath, err)
}
continue
}

file, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, info.Mode())
if err := writeBundleFile(destPath, info.Mode(), tr); err != nil {
return fmt.Errorf("writing file: %w", err)
}
}
return nil
}

// UntarBundleWithRequiredFilePermission performs the same operation as UntarBundle,
// but enforces `requiredFilePerm` for all files in the bundle.
func UntarBundleWithRequiredFilePermission(destination string, source string, requiredFilePerm fs.FileMode) error {
f, err := os.Open(source)
if err != nil {
return fmt.Errorf("opening source: %w", err)
}
defer f.Close()

gzr, err := gzip.NewReader(f)
if err != nil {
return fmt.Errorf("creating gzip reader from %s: %w", source, err)
}
defer gzr.Close()

tr := tar.NewReader(gzr)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return errors.Wrapf(err, "open file %s", path)
return fmt.Errorf("reading tar file: %w", err)
}

if err := sanitizeExtractPath(filepath.Dir(destination), header.Name); err != nil {
return fmt.Errorf("checking filename: %w", err)
}
defer file.Close()
if _, err := io.Copy(file, tr); err != nil {
return errors.Wrapf(err, "copy tar %s to destination %s", header.FileInfo().Name(), path)

destPath := filepath.Join(filepath.Dir(destination), header.Name)
info := header.FileInfo()
if info.IsDir() {
if err = os.MkdirAll(destPath, info.Mode()); err != nil {
return fmt.Errorf("creating directory %s for tar file: %w", destPath, err)
}
continue
}

if err := writeBundleFile(destPath, requiredFilePerm, tr); err != nil {
return fmt.Errorf("writing file: %w", err)
}
}
return nil
}

func writeBundleFile(destPath string, perm fs.FileMode, srcReader io.Reader) error {
file, err := os.OpenFile(destPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, perm)
if err != nil {
return fmt.Errorf("opening %s: %w", destPath, err)
}
defer file.Close()
if _, err := io.Copy(file, srcReader); err != nil {
return fmt.Errorf("copying to %s: %w", destPath, err)
}

return nil
}

// sanitizeExtractPath checks that the supplied extraction path is nor
// vulnerable to zip slip attacks. See https://snyk.io/research/zip-slip-vulnerability
func sanitizeExtractPath(filePath string, destination string) error {
destpath := filepath.Join(destination, filePath)
if !strings.HasPrefix(destpath, filepath.Clean(destination)+string(os.PathSeparator)) {
return errors.Errorf("%s: illegal file path", filePath)
return fmt.Errorf("%s: illegal file path", filePath)
}
return nil
}
134 changes: 134 additions & 0 deletions fsutil/filesystem_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,145 @@
package fsutil

import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestUntarBundle(t *testing.T) {
t.Parallel()

// Create tarball contents
originalDir := t.TempDir()
topLevelFile := filepath.Join(originalDir, "testfile.txt")
var topLevelFileMode fs.FileMode = 0655
require.NoError(t, os.WriteFile(topLevelFile, []byte("test1"), topLevelFileMode))
internalDir := filepath.Join(originalDir, "some", "path", "to")
var nestedFileMode fs.FileMode = 0755
require.NoError(t, os.MkdirAll(internalDir, nestedFileMode))
nestedFile := filepath.Join(internalDir, "anotherfile.txt")
require.NoError(t, os.WriteFile(nestedFile, []byte("test2"), nestedFileMode))

// Create test tarball
tarballDir := t.TempDir()
tarballFile := filepath.Join(tarballDir, "test.gz")
createTar(t, tarballFile, originalDir)

// Confirm we can untar the tarball successfully
newDir := t.TempDir()
require.NoError(t, UntarBundle(filepath.Join(newDir, "anything"), tarballFile))

// Confirm the tarball has the contents we expect
newTopLevelFile := filepath.Join(newDir, filepath.Base(topLevelFile))
require.FileExists(t, newTopLevelFile)
newNestedFile := filepath.Join(newDir, "some", "path", "to", filepath.Base(nestedFile))
require.FileExists(t, newNestedFile)

// Confirm each file retained its original permissions
topLevelFileInfo, err := os.Stat(newTopLevelFile)
require.NoError(t, err)
require.Equal(t, topLevelFileMode, topLevelFileInfo.Mode())
nestedFileInfo, err := os.Stat(newNestedFile)
require.NoError(t, err)
require.Equal(t, nestedFileMode, nestedFileInfo.Mode())
}

func TestUntarBundleWithRequiredFilePermission(t *testing.T) {
t.Parallel()

// Create tarball contents
originalDir := t.TempDir()
topLevelFile := filepath.Join(originalDir, "testfile.txt")
require.NoError(t, os.WriteFile(topLevelFile, []byte("test1"), 0655))
internalDir := filepath.Join(originalDir, "some", "path", "to")
require.NoError(t, os.MkdirAll(internalDir, 0744))
nestedFile := filepath.Join(internalDir, "anotherfile.txt")
require.NoError(t, os.WriteFile(nestedFile, []byte("test2"), 0744))

// Create test tarball
tarballDir := t.TempDir()
tarballFile := filepath.Join(tarballDir, "test.gz")
createTar(t, tarballFile, originalDir)

// Confirm we can untar the tarball successfully
newDir := t.TempDir()
var requiredFileMode fs.FileMode = 0755
require.NoError(t, UntarBundleWithRequiredFilePermission(filepath.Join(newDir, "anything"), tarballFile, requiredFileMode))

// Confirm the tarball has the contents we expect
newTopLevelFile := filepath.Join(newDir, filepath.Base(topLevelFile))
require.FileExists(t, newTopLevelFile)
newNestedFile := filepath.Join(newDir, "some", "path", "to", filepath.Base(nestedFile))
require.FileExists(t, newNestedFile)

// Require that both files have the required permission 0755
topLevelFileInfo, err := os.Stat(newTopLevelFile)
require.NoError(t, err)
require.Equal(t, requiredFileMode, topLevelFileInfo.Mode())
nestedFileInfo, err := os.Stat(newNestedFile)
require.NoError(t, err)
require.Equal(t, requiredFileMode, nestedFileInfo.Mode())
}

// createTar is a helper to create a test tar
func createTar(t *testing.T, createLocation string, sourceDir string) {
tarballFile, err := os.Create(createLocation)
require.NoError(t, err)
defer tarballFile.Close()

gzw := gzip.NewWriter(tarballFile)
defer gzw.Close()

tw := tar.NewWriter(gzw)
defer tw.Close()

require.NoError(t, filepath.Walk(sourceDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}

srcInfo, err := os.Lstat(path)
if os.IsNotExist(err) {
return fmt.Errorf("error adding %s to tarball: %w", path, err)
}

hdr, err := tar.FileInfoHeader(srcInfo, path)
if err != nil {
return fmt.Errorf("error creating tar header: %w", err)
}
hdr.Name = strings.TrimPrefix(path, sourceDir+"/")

if err := tw.WriteHeader(hdr); err != nil {
return fmt.Errorf("error writing tar header: %w", err)
}

if !srcInfo.Mode().IsRegular() {
// Don't open/copy over directories
return nil
}

srcFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("error opening file to add to tarball: %w", err)
}
defer srcFile.Close()

if _, err := io.Copy(tw, srcFile); err != nil {
return fmt.Errorf("error copying file %s to tarball: %w", path, err)
}

return nil
}))
}

func TestSanitizeExtractPath(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit a88c068

Please sign in to comment.