diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go
index c64495d781e..26117401cb7 100644
--- a/pkg/downloader/downloader.go
+++ b/pkg/downloader/downloader.go
@@ -12,14 +12,15 @@ import (
 	"os/exec"
 	"path"
 	"path/filepath"
-	"strconv"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	"github.com/cheggaaa/pb/v3"
 	"github.com/containerd/continuity/fs"
 	"github.com/lima-vm/lima/pkg/httpclientutil"
 	"github.com/lima-vm/lima/pkg/localpathutil"
+	"github.com/lima-vm/lima/pkg/lockutil"
 	"github.com/lima-vm/lima/pkg/progressbar"
 	"github.com/opencontainers/go-digest"
 	"github.com/sirupsen/logrus"
@@ -267,21 +268,21 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
 		return nil, err
 	}
 	shadURL := filepath.Join(shad, "url")
-	if err := atomicWrite(shadURL, []byte(remote), 0o644); err != nil {
+	if err := writeFirst(shadURL, []byte(remote), 0o644); err != nil {
 		return nil, err
 	}
 	if err := downloadHTTP(ctx, shadData, shadTime, shadType, remote, o.description, o.expectedDigest); err != nil {
 		return nil, err
 	}
-	// no need to pass the digest to copyLocal(), as we already verified the digest
-	if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
-		return nil, err
-	}
 	if shadDigest != "" && o.expectedDigest != "" {
-		if err := atomicWrite(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
+		if err := writeFirst(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
 			return nil, err
 		}
 	}
+	// no need to pass the digest to copyLocal(), as we already verified the digest
+	if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
+		return nil, err
+	}
 	res := &Result{
 		Status:    StatusDownloaded,
 		CachePath: shadData,
@@ -605,13 +606,13 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 	}
 	if lastModified != "" {
 		lm := resp.Header.Get("Last-Modified")
-		if err := atomicWrite(lastModified, []byte(lm), 0o644); err != nil {
+		if err := writeFirst(lastModified, []byte(lm), 0o644); err != nil {
 			return err
 		}
 	}
 	if contentType != "" {
 		ct := resp.Header.Get("Content-Type")
-		if err := atomicWrite(contentType, []byte(ct), 0o644); err != nil {
+		if err := writeFirst(contentType, []byte(ct), 0o644); err != nil {
 			return err
 		}
 	}
@@ -672,43 +673,43 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 		return err
 	}
 
-	return os.Rename(localPathTmp, localPath)
+	// If localPath was created by a parallel download, keep it. Replacing it
+	// while another process is copying it to the destination may fail the
+	// clonefile syscall. We use a lock to ensure that only one process updates
+	// the data file, and that it exists when we return.
+
+	return lockutil.WithDirLock(filepath.Dir(localPath), func() error {
+		if _, err := os.Stat(localPath); err == nil {
+			return nil
+		} else if !errors.Is(err, os.ErrNotExist) {
+			return err
+		}
+		return os.Rename(localPathTmp, localPath)
+	})
 }
 
+var tempfileCount atomic.Uint64
+
 // To allow parallel download we use a per-process unique suffix for tempoary
 // files. Renaming the temporary file to the final file is safe without
 // synchronization on posix.
+// To make it easy to test we also include a counter ensuring that each
+// temporary file is unique in the same process.
 // https://github.com/lima-vm/lima/issues/2722
 func perProcessTempfile(path string) string {
-	return path + ".tmp." + strconv.FormatInt(int64(os.Getpid()), 10)
+	return fmt.Sprintf("%s.tmp.%d.%d", path, os.Getpid(), tempfileCount.Add(1))
 }
 
-// atomicWrite writes data to path, creating a new file or replacing existing
-// one. Multiple processess can write to the same path safely. Safe on posix and
-// likely safe on windows when using NTFS.
-func atomicWrite(path string, data []byte, perm os.FileMode) error {
-	tmpPath := perProcessTempfile(path)
-	tmp, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
-	if err != nil {
-		return err
-	}
-	defer func() {
-		if err != nil {
-			tmp.Close()
-			os.RemoveAll(tmpPath)
+// writeFirst writes data to path unless path already exists.
+func writeFirst(path string, data []byte, perm os.FileMode) error {
+	return lockutil.WithDirLock(filepath.Dir(path), func() error {
+		if _, err := os.Stat(path); err == nil {
+			return nil
+		} else if !errors.Is(err, os.ErrNotExist) {
+			return err
 		}
-	}()
-	if _, err = tmp.Write(data); err != nil {
-		return err
-	}
-	if err = tmp.Sync(); err != nil {
-		return err
-	}
-	if err = tmp.Close(); err != nil {
-		return err
-	}
-	err = os.Rename(tmpPath, path)
-	return err
+		return os.WriteFile(path, data, perm)
+	})
 }
 
 // CacheEntries returns a map of cache entries.
diff --git a/pkg/downloader/downloader_test.go b/pkg/downloader/downloader_test.go
index 136de343670..00c42c6f55d 100644
--- a/pkg/downloader/downloader_test.go
+++ b/pkg/downloader/downloader_test.go
@@ -8,6 +8,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -21,6 +22,20 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
+type downloadResult struct {
+	r   *Result
+	err error
+}
+
+// We expect only a few parallel downloads. We test with a larger number to
+// find races quicker. 20 parallel downloads take about 120 milliseconds on M1 Pro.
+const parallelDownloads = 20
+
+// When downloading in parallel, usually all downloads complete with
+// StatusDownloaded, but some may be delayed and find the data file when they
+// start. Can be reproduced locally using 100 parallel downloads.
+var parallelStatus = []Status{StatusDownloaded, StatusUsedCache}
+
 func TestDownloadRemote(t *testing.T) {
 	ts := httptest.NewServer(http.FileServer(http.Dir("testdata")))
 	t.Cleanup(ts.Close)
@@ -57,38 +72,90 @@
 		})
 	})
 	t.Run("with cache", func(t *testing.T) {
-		cacheDir := filepath.Join(t.TempDir(), "cache")
-		localPath := filepath.Join(t.TempDir(), t.Name())
-		r, err := Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusDownloaded, r.Status)
+		t.Run("serial", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			localPath := filepath.Join(t.TempDir(), t.Name())
+			r, err := Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusDownloaded, r.Status)
 
-		r, err = Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusSkipped, r.Status)
+			r, err = Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusSkipped, r.Status)
 
-		localPath2 := localPath + "-2"
-		r, err = Download(context.Background(), localPath2, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			localPath2 := localPath + "-2"
+			r, err = Download(context.Background(), localPath2, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
+		})
+		t.Run("parallel", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			results := make(chan downloadResult, parallelDownloads)
+			for i := 0; i < parallelDownloads; i++ {
+				go func() {
+					// Parallel download is supported only for different instances with unique localPath.
+					localPath := filepath.Join(t.TempDir(), t.Name())
+					r, err := Download(context.Background(), localPath, dummyRemoteFileURL,
+						WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+					results <- downloadResult{r, err}
+				}()
+			}
+			// We must process all results before cleanup.
+			for i := 0; i < parallelDownloads; i++ {
+				result := <-results
+				if result.err != nil {
+					t.Errorf("Download failed: %s", result.err)
+				} else if !slices.Contains(parallelStatus, result.r.Status) {
+					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
+				}
+			}
+		})
 	})
 	t.Run("caching-only mode", func(t *testing.T) {
-		_, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))
-		assert.ErrorContains(t, err, "cache directory to be specified")
+		t.Run("serial", func(t *testing.T) {
+			_, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))
+			assert.ErrorContains(t, err, "cache directory to be specified")
 
-		cacheDir := filepath.Join(t.TempDir(), "cache")
-		r, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusDownloaded, r.Status)
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			r, err := Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest),
+				WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusDownloaded, r.Status)
 
-		r, err = Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			r, err = Download(context.Background(), "", dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest),
+				WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
 
-		localPath := filepath.Join(t.TempDir(), t.Name())
-		r, err = Download(context.Background(), localPath, dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
-		assert.NilError(t, err)
-		assert.Equal(t, StatusUsedCache, r.Status)
+			localPath := filepath.Join(t.TempDir(), t.Name())
+			r, err = Download(context.Background(), localPath, dummyRemoteFileURL,
+				WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+			assert.NilError(t, err)
+			assert.Equal(t, StatusUsedCache, r.Status)
+		})
+		t.Run("parallel", func(t *testing.T) {
+			cacheDir := filepath.Join(t.TempDir(), "cache")
+			results := make(chan downloadResult, parallelDownloads)
+			for i := 0; i < parallelDownloads; i++ {
+				go func() {
+					r, err := Download(context.Background(), "", dummyRemoteFileURL,
+						WithExpectedDigest(dummyRemoteFileDigest), WithCacheDir(cacheDir))
+					results <- downloadResult{r, err}
+				}()
+			}
+			// We must process all results before cleanup.
+			for i := 0; i < parallelDownloads; i++ {
+				result := <-results
+				if result.err != nil {
+					t.Errorf("Download failed: %s", result.err)
+				} else if !slices.Contains(parallelStatus, result.r.Status) {
+					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
+				}
+			}
+		})
 	})
 	t.Run("cached", func(t *testing.T) {
 		_, err := Cached(dummyRemoteFileURL, WithExpectedDigest(dummyRemoteFileDigest))