From efb60afdee168d8eab9defb06d6d0ae2965201ec Mon Sep 17 00:00:00 2001 From: Lisa Ugray Date: Mon, 1 Apr 2024 12:07:14 -0400 Subject: [PATCH] Allow overwriting symlinks When extracting regular files, the file is opened for writing with os.O_TRUNC, making it suitable to extract over an existing file. This allows updating a local copy from an updated archive by extracting over it. If a link is added to the archive though, the second time it's extracted, an error occurs since os.Link and os.Symlink will not overwrite. This removes any file currently in place before attempting to write a link. --- extract.go | 2 ++ extract_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/extract.go b/extract.go index 5c68e0f..1a74094 100644 --- a/extract.go +++ b/extract.go @@ -116,6 +116,7 @@ func Zip(ctx context.Context, body io.Reader, location string, rename Renamer) e type fs struct{} func (f fs) Link(oldname, newname string) error { + _ = os.Remove(newname) // Ignore error. We don't care if the file doesn't exist. return os.Link(oldname, newname) } @@ -124,6 +125,7 @@ func (f fs) MkdirAll(path string, perm os.FileMode) error { } func (f fs) Symlink(oldname, newname string) error { + _ = os.Remove(newname) // Ignore error. We don't care if the file doesn't exist. return os.Symlink(oldname, newname) } diff --git a/extract_test.go b/extract_test.go index 3ad01f2..7b088c1 100644 --- a/extract_test.go +++ b/extract_test.go @@ -220,6 +220,48 @@ func TestExtract(t *testing.T) { } } +func TestExtractIdempotency(t *testing.T) { + for _, test := range ExtractCases { + dir, _ := os.MkdirTemp("", "") + dir = filepath.Join(dir, "test") + data, err := os.ReadFile(test.Archive) + if err != nil { + t.Fatal(err) + } + + var extractFn func(context.Context, io.Reader, string, extract.Renamer) error + switch filepath.Ext(test.Archive) { + case ".bz2": + extractFn = extract.Bz2 + case ".gz": + extractFn = extract.Gz + case ".zip": + extractFn = extract.Zip + case ".mistery": + extractFn = extract.Archive + default: + t.Fatal("unknown error") + } + + buffer := bytes.NewBuffer(data) + if err = extractFn(context.Background(), buffer, dir, test.Renamer); err != nil { + t.Fatal(test.Name, ": Should not fail first extraction: "+err.Error()) + } + + buffer = bytes.NewBuffer(data) + if err = extractFn(context.Background(), buffer, dir, test.Renamer); err != nil { + t.Fatal(test.Name, ": Should not fail second extraction: "+err.Error()) + } + + testWalk(t, dir, test.Files) + + err = os.RemoveAll(dir) + if err != nil { + t.Fatal(err) + } + } +} + func BenchmarkArchive(b *testing.B) { dir, _ := os.MkdirTemp("", "") data, _ := os.ReadFile("testdata/archive.tar.bz2")