Merge pull request #2728 from nirs/downloader-parallel-test
Test parallel downloads
AkihiroSuda authored Oct 15, 2024
2 parents 6768a56 + 709f513 commit 6f7569f
Showing 2 changed files with 129 additions and 61 deletions.
73 changes: 37 additions & 36 deletions pkg/downloader/downloader.go
@@ -12,14 +12,15 @@ import (
 	"os/exec"
 	"path"
 	"path/filepath"
-	"strconv"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	"github.com/cheggaaa/pb/v3"
 	"github.com/containerd/continuity/fs"
 	"github.com/lima-vm/lima/pkg/httpclientutil"
 	"github.com/lima-vm/lima/pkg/localpathutil"
+	"github.com/lima-vm/lima/pkg/lockutil"
 	"github.com/lima-vm/lima/pkg/progressbar"
 	"github.com/opencontainers/go-digest"
 	"github.com/sirupsen/logrus"
@@ -267,21 +268,21 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
 		return nil, err
 	}
 	shadURL := filepath.Join(shad, "url")
-	if err := atomicWrite(shadURL, []byte(remote), 0o644); err != nil {
+	if err := writeFirst(shadURL, []byte(remote), 0o644); err != nil {
 		return nil, err
 	}
 	if err := downloadHTTP(ctx, shadData, shadTime, shadType, remote, o.description, o.expectedDigest); err != nil {
 		return nil, err
 	}
-	// no need to pass the digest to copyLocal(), as we already verified the digest
-	if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
-		return nil, err
-	}
 	if shadDigest != "" && o.expectedDigest != "" {
-		if err := atomicWrite(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
+		if err := writeFirst(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
 			return nil, err
 		}
 	}
+	// no need to pass the digest to copyLocal(), as we already verified the digest
+	if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
+		return nil, err
+	}
 	res := &Result{
 		Status:    StatusDownloaded,
 		CachePath: shadData,
@@ -605,13 +606,13 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 	}
 	if lastModified != "" {
 		lm := resp.Header.Get("Last-Modified")
-		if err := atomicWrite(lastModified, []byte(lm), 0o644); err != nil {
+		if err := writeFirst(lastModified, []byte(lm), 0o644); err != nil {
 			return err
 		}
 	}
 	if contentType != "" {
 		ct := resp.Header.Get("Content-Type")
-		if err := atomicWrite(contentType, []byte(ct), 0o644); err != nil {
+		if err := writeFirst(contentType, []byte(ct), 0o644); err != nil {
 			return err
 		}
 	}
@@ -672,43 +673,43 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 		return err
 	}
 
-	return os.Rename(localPathTmp, localPath)
+	// If localPath was created by a parallel download, keep it. Replacing it
+	// while another process is copying it to the destination may fail the
+	// clonefile syscall. We use a lock to ensure that only one process updates
+	// data, and that the data file exists when we return.
+
+	return lockutil.WithDirLock(filepath.Dir(localPath), func() error {
+		if _, err := os.Stat(localPath); err == nil {
+			return nil
+		} else if !errors.Is(err, os.ErrNotExist) {
+			return err
+		}
+		return os.Rename(localPathTmp, localPath)
+	})
 }
 
+var tempfileCount atomic.Uint64
+
 // To allow parallel download we use a per-process unique suffix for temporary
 // files. Renaming the temporary file to the final file is safe without
 // synchronization on posix.
+// To make it easy to test, we also include a counter ensuring that each
+// temporary file is unique within the same process.
 // https://github.com/lima-vm/lima/issues/2722
 func perProcessTempfile(path string) string {
-	return path + ".tmp." + strconv.FormatInt(int64(os.Getpid()), 10)
+	return fmt.Sprintf("%s.tmp.%d.%d", path, os.Getpid(), tempfileCount.Add(1))
 }
 
-// atomicWrite writes data to path, creating a new file or replacing existing
-// one. Multiple processes can write to the same path safely. Safe on posix and
-// likely safe on windows when using NTFS.
-func atomicWrite(path string, data []byte, perm os.FileMode) error {
-	tmpPath := perProcessTempfile(path)
-	tmp, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
-	if err != nil {
-		return err
-	}
-	defer func() {
-		if err != nil {
-			tmp.Close()
-			os.RemoveAll(tmpPath)
+// writeFirst writes data to path unless path already exists.
+func writeFirst(path string, data []byte, perm os.FileMode) error {
+	return lockutil.WithDirLock(filepath.Dir(path), func() error {
+		if _, err := os.Stat(path); err == nil {
+			return nil
+		} else if !errors.Is(err, os.ErrNotExist) {
+			return err
 		}
-	}()
-	if _, err = tmp.Write(data); err != nil {
-		return err
-	}
-	if err = tmp.Sync(); err != nil {
-		return err
-	}
-	if err = tmp.Close(); err != nil {
-		return err
-	}
-	err = os.Rename(tmpPath, path)
-	return err
+		return os.WriteFile(path, data, perm)
+	})
 }
 
 // CacheEntries returns a map of cache entries.
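The new writeFirst helper implements a first-writer-wins policy: whichever caller takes the directory lock first creates the file, and every later caller keeps the existing file untouched. The following is a minimal, self-contained sketch of that pattern, not part of the commit; writeFirstSketch is a hypothetical name, and an in-process sync.Mutex stands in for lockutil.WithDirLock, so it models only the single-process case.

package main

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
)

// mu stands in for lockutil.WithDirLock; the real helper locks the parent
// directory so the check-then-write below is also atomic across processes.
var mu sync.Mutex

func writeFirstSketch(path string, data []byte, perm os.FileMode) error {
	mu.Lock()
	defer mu.Unlock()
	if _, err := os.Stat(path); err == nil {
		return nil // another writer won the race; keep its file
	} else if !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return os.WriteFile(path, data, perm)
}

func main() {
	path := filepath.Join(os.TempDir(), "writefirst-demo")
	defer os.Remove(path)
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if err := writeFirstSketch(path, []byte(fmt.Sprintf("writer %d\n", i)), 0o644); err != nil {
				fmt.Println("write failed:", err)
			}
		}(i)
	}
	wg.Wait()
	content, err := os.ReadFile(path)
	if err != nil {
		panic(err)
	}
	fmt.Print("winner: ", string(content)) // exactly one writer's content survives intact
}

The downloadHTTP change above uses the same check-then-rename shape, except that the payload is a rename of a per-process temporary file (named like data.tmp.&lt;pid&gt;.&lt;counter&gt; per perProcessTempfile) rather than os.WriteFile.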
117 changes: 92 additions & 25 deletions pkg/downloader/downloader_test.go
@@ -8,6 +8,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -21,6 +22,20 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
+type downloadResult struct {
+	r   *Result
+	err error
+}
+
+// We expect only a few parallel downloads. Testing with a larger number finds
+// races quicker. 20 parallel downloads take about 120 milliseconds on an M1 Pro.
+const parallelDownloads = 20
+
+// When downloading in parallel, usually all downloads complete with
+// StatusDownloaded, but some may be delayed and find the data file when they
+// start. This can be reproduced locally using 100 parallel downloads.
+var parallelStatus = []Status{StatusDownloaded, StatusUsedCache}
+
 func TestDownloadRemote(t *testing.T) {
 	ts := httptest.NewServer(http.FileServer(http.Dir("testdata")))
 	t.Cleanup(ts.Close)
@@ -57,38 +72,90 @@ func TestDownloadRemote(t *testing.T) {
 		})
 	})
 	t.Run("with cache", func(t *testing.T) {
-		cacheDir := filepath.Join(t.TempDir(), "cache")
-		localPath := filepath.Join(t.TempDir(), t.Name())
-		r, err := Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusDownloaded, r.Status)
+		t.Run("serial", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			localPath := filepath.Join(t.TempDir(), t.Name())
+			r, err := Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusDownloaded, r.Status)
 
-		r, err = Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusSkipped, r.Status)
+			r, err = Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusSkipped, r.Status)
 
-		localPath2 := localPath + "-2"
-		r, err = Download(context.Background(), localPath2, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			localPath2 := localPath + "-2"
+			r, err = Download(context.Background(), localPath2, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
+		})
+		t.Run("parallel", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			results := make(chan downloadResult, parallelDownloads)
+			for i := 0; i < parallelDownloads; i++ {
+				go func() {
+					// Parallel download is supported only for different instances with unique localPath.
+					localPath := filepath.Join(t.TempDir(), t.Name())
+					r, err := Download(context.Background(), localPath, dummyRemoteFileURL,
+						WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+					results <- downloadResult{r, err}
+				}()
+			}
+			// We must process all results before cleanup.
+			for i := 0; i < parallelDownloads; i++ {
+				result := <-results
+				if result.err != nil {
+					t.Errorf("Download failed: %s", result.err)
+				} else if !slices.Contains(parallelStatus, result.r.Status) {
+					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
+				}
+			}
+		})
 	})
 	t.Run("caching-only mode", func(t *testing.T) {
-		_, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))
-		assert.ErrorContains(t, err, "cache directory to be specified")
+		t.Run("serial", func(t *testing.T) {
+			_, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))
+			assert.ErrorContains(t, err, "cache directory to be specified")
 
-		cacheDir := filepath.Join(t.TempDir(), "cache")
-		r, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusDownloaded, r.Status)
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			r, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest),
+				WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusDownloaded, r.Status)
 
-		r, err = Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			r, err = Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest),
+				WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
 
-		localPath := filepath.Join(t.TempDir(), t.Name())
-		r, err = Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			localPath := filepath.Join(t.TempDir(), t.Name())
+			r, err = Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
+		})
+		t.Run("parallel", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			results := make(chan downloadResult, parallelDownloads)
+			for i := 0; i < parallelDownloads; i++ {
+				go func() {
+					r, err := Download(context.Background(), "", dummyRemoteFileURL,
+						WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+					results <- downloadResult{r, err}
+				}()
+			}
+			// We must process all results before cleanup.
+			for i := 0; i < parallelDownloads; i++ {
+				result := <-results
+				if result.err != nil {
+					t.Errorf("Download failed: %s", result.err)
+				} else if !slices.Contains(parallelStatus, result.r.Status) {
+					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
+				}
+			}
+		})
 	})
 	t.Run("cached", func(t *testing.T) {
 		_, err := Cached(dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))
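One way to exercise only the new parallel subtests, assuming the working directory is the lima repository root (go test replaces spaces in subtest names with underscores; -race is optional but useful when hunting for races like this):

go test -race -run 'TestDownloadRemote/with_cache/parallel' ./pkg/downloader/
go test -race -run 'TestDownloadRemote/caching-only_mode/parallel' ./pkg/downloader/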