From 05aa01407d8162e7db0dfde510f14e170e9e4386 Mon Sep 17 00:00:00 2001 From: Ivan Dagelic Date: Fri, 1 Nov 2024 10:00:46 +0100 Subject: [PATCH] fix: persist default provider deletion (#1288) Signed-off-by: Ivan Dagelic --- pkg/api/controllers/provider/install.go | 4 +- pkg/provider/manager/error.go | 22 +++++++ pkg/provider/manager/installer.go | 7 +- pkg/provider/manager/manager.go | 87 +++++++++++++++++-------- pkg/server/providers.go | 47 ++++++++++--- 5 files changed, 122 insertions(+), 45 deletions(-) create mode 100644 pkg/provider/manager/error.go diff --git a/pkg/api/controllers/provider/install.go b/pkg/api/controllers/provider/install.go index e9b0f7c366..4cfc8a5d68 100644 --- a/pkg/api/controllers/provider/install.go +++ b/pkg/api/controllers/provider/install.go @@ -40,13 +40,13 @@ func InstallProvider(ctx *gin.Context) { } } - downloadPath, err := server.ProviderManager.DownloadProvider(ctx.Request.Context(), req.DownloadUrls, req.Name, true) + downloadPath, err := server.ProviderManager.DownloadProvider(ctx.Request.Context(), req.DownloadUrls, req.Name) if err != nil { ctx.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to download provider: %w", err)) return } - err = server.ProviderManager.RegisterProvider(downloadPath) + err = server.ProviderManager.RegisterProvider(downloadPath, true) if err != nil { ctx.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to register provider: %w", err)) return diff --git a/pkg/provider/manager/error.go b/pkg/provider/manager/error.go new file mode 100644 index 0000000000..e8faa1f393 --- /dev/null +++ b/pkg/provider/manager/error.go @@ -0,0 +1,22 @@ +// Copyright 2024 Daytona Platforms Inc. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import "fmt" + +func IsProviderAlreadyDownloaded(err error, name string) bool { + return err.Error() == providerAlreadyDownloadedError(name).Error() +} + +func IsNoPluginFound(err error, dir string) bool { + return err.Error() == noPluginFoundError(dir).Error() +} + +func providerAlreadyDownloadedError(name string) error { + return fmt.Errorf("provider %s already installed", name) +} + +func noPluginFoundError(dir string) error { + return fmt.Errorf("no plugin found in %s", dir) +} diff --git a/pkg/provider/manager/installer.go b/pkg/provider/manager/installer.go index cfe48d3a19..239928ee74 100644 --- a/pkg/provider/manager/installer.go +++ b/pkg/provider/manager/installer.go @@ -41,17 +41,14 @@ func (m *ProviderManager) GetProvidersManifest() (*ProvidersManifest, error) { return &manifest, nil } -func (m *ProviderManager) DownloadProvider(ctx context.Context, downloadUrls map[os.OperatingSystem]string, providerName string, throwIfPresent bool) (string, error) { +func (m *ProviderManager) DownloadProvider(ctx context.Context, downloadUrls map[os.OperatingSystem]string, providerName string) (string, error) { downloadPath := filepath.Join(m.baseDir, providerName, providerName) if runtime.GOOS == "windows" { downloadPath += ".exe" } if _, err := goos.Stat(downloadPath); err == nil { - if throwIfPresent { - return "", fmt.Errorf("provider %s already downloaded", providerName) - } - return "", nil + return "", providerAlreadyDownloadedError(providerName) } log.Info("Downloading " + providerName) diff --git a/pkg/provider/manager/manager.go b/pkg/provider/manager/manager.go index caec4a0c89..6f554f3af0 100644 --- a/pkg/provider/manager/manager.go +++ b/pkg/provider/manager/manager.go @@ -6,6 +6,7 @@ package manager import ( "context" "errors" + "fmt" "os" "os/exec" "path/filepath" @@ -22,6 +23,8 @@ import ( log "github.com/sirupsen/logrus" ) +const INITIAL_SETUP_LOCK_FILE_NAME = "initial-setup.lock" + type pluginRef struct { client *plugin.Client path string @@ -35,11 +38,11 @@ var ProviderHandshakeConfig = plugin.HandshakeConfig{ } type IProviderManager interface { - DownloadProvider(ctx context.Context, downloadUrls map[os_util.OperatingSystem]string, providerName string, throwIfPresent bool) (string, error) + DownloadProvider(ctx context.Context, downloadUrls map[os_util.OperatingSystem]string, providerName string) (string, error) GetProvider(name string) (*Provider, error) GetProviders() map[string]Provider GetProvidersManifest() (*ProvidersManifest, error) - RegisterProvider(pluginPath string) error + RegisterProvider(pluginPath string, manualInstall bool) error TerminateProviderProcesses(providersBasePath string) error UninstallProvider(name string) error Purge() error @@ -129,7 +132,7 @@ func (m *ProviderManager) GetProviders() map[string]Provider { return providers } -func (m *ProviderManager) RegisterProvider(pluginPath string) error { +func (m *ProviderManager) RegisterProvider(pluginPath string, manualInstall bool) error { pluginRef, err := m.initializeProvider(pluginPath) if err != nil { return err @@ -137,36 +140,40 @@ func (m *ProviderManager) RegisterProvider(pluginPath string) error { m.pluginRefs[pluginRef.name] = pluginRef - p, err := m.dispenseProvider(pluginRef.client, pluginRef.name) - if err != nil { - return err - } - - existingTargets, err := m.providerTargetService.Map() - if err != nil { - return errors.New("failed to get targets: " + err.Error()) - } - - presetTargets, err := (*p).GetPresetTargets() - if err != nil { - return errors.New("failed to get preset targets: " + err.Error()) - } + lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME) + _, err = os.Stat(lockFilePath) + if os.IsNotExist(err) || manualInstall { + p, err := m.GetProvider(pluginRef.name) + if err != nil { + return fmt.Errorf("failed to get provider: %w", err) + } - log.Info("Setting preset targets") - for _, target := range *presetTargets { - if _, ok := existingTargets[target.Name]; ok { - log.Infof("Target %s already exists. Skipping...", target.Name) - continue + existingTargets, err := m.providerTargetService.Map() + if err != nil { + return errors.New("failed to get targets: " + err.Error()) } - err := m.providerTargetService.Save(&target) + presetTargets, err := (*p).GetPresetTargets() if err != nil { - log.Errorf("Failed to set target %s: %s", target.Name, err) - } else { - log.Infof("Target %s set", target.Name) + return errors.New("failed to get preset targets: " + err.Error()) + } + + log.Infof("Setting preset targets for %s", pluginRef.name) + for _, target := range *presetTargets { + if _, ok := existingTargets[target.Name]; ok { + log.Infof("Target %s already exists. Skipping...", target.Name) + continue + } + + err := m.providerTargetService.Save(&target) + if err != nil { + log.Errorf("Failed to set target %s: %s", target.Name, err) + } else { + log.Infof("Target %s set", target.Name) + } } + log.Infof("Preset targets set for %s", pluginRef.name) } - log.Info("Preset targets set") log.Infof("Provider %s initialized", pluginRef.name) @@ -180,11 +187,35 @@ func (m *ProviderManager) UninstallProvider(name string) error { } pluginRef.client.Kill() - err := os.RemoveAll(pluginRef.path) + lockFileExisted := false + lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME) + _, err := os.Stat(lockFilePath) + if err == nil { + lockFileExisted = true + } + + err = os.RemoveAll(pluginRef.path) if err != nil { return errors.New("failed to remove provider: " + err.Error()) } + if lockFileExisted { + // After clearing up the contents, remake the directory and add a lock file that + // will be used to ensure that the provider is not reinstalled automatically + err = os.MkdirAll(pluginRef.path, os.ModePerm) + if err != nil { + return err + } + + lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME) + + file, err := os.Create(lockFilePath) + if err != nil { + return err + } + defer file.Close() + } + delete(m.pluginRefs, name) return nil diff --git a/pkg/server/providers.go b/pkg/server/providers.go index 031c8c9d0f..5718238d32 100644 --- a/pkg/server/providers.go +++ b/pkg/server/providers.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" + "github.com/daytonaio/daytona/pkg/provider/manager" log "github.com/sirupsen/logrus" ) @@ -22,9 +23,19 @@ func (s *Server) downloadDefaultProviders() error { log.Info("Downloading default providers") for providerName, provider := range defaultProviders { - _, err = s.ProviderManager.DownloadProvider(context.Background(), provider.DownloadUrls, providerName, false) + lockFilePath := filepath.Join(s.config.ProvidersDir, providerName, manager.INITIAL_SETUP_LOCK_FILE_NAME) + + _, err := os.Stat(lockFilePath) + if err == nil { + continue + } + + _, err = s.ProviderManager.DownloadProvider(context.Background(), provider.DownloadUrls, providerName) if err != nil { - log.Error(err) + if !manager.IsProviderAlreadyDownloaded(err, providerName) { + log.Error(err) + } + continue } } @@ -41,7 +52,7 @@ func (s *Server) registerProviders() error { return err } - files, err := os.ReadDir(s.config.ProvidersDir) + directoryEntries, err := os.ReadDir(s.config.ProvidersDir) if err != nil { if os.IsNotExist(err) { log.Info("No providers found") @@ -50,22 +61,38 @@ func (s *Server) registerProviders() error { return err } - for _, file := range files { - if file.IsDir() { - pluginPath, err := s.getPluginPath(filepath.Join(s.config.ProvidersDir, file.Name())) + for _, entry := range directoryEntries { + if entry.IsDir() { + providerDir := filepath.Join(s.config.ProvidersDir, entry.Name()) + + pluginPath, err := s.getPluginPath(providerDir) if err != nil { - log.Error(err) + if !manager.IsNoPluginFound(err, providerDir) { + log.Error(err) + } continue } - err = s.ProviderManager.RegisterProvider(pluginPath) + err = s.ProviderManager.RegisterProvider(pluginPath, false) if err != nil { log.Error(err) continue } + // Lock the initial setup + lockFilePath := filepath.Join(s.config.ProvidersDir, entry.Name(), manager.INITIAL_SETUP_LOCK_FILE_NAME) + + _, err = os.Stat(lockFilePath) + if err != nil { + file, err := os.Create(lockFilePath) + if err != nil { + return err + } + defer file.Close() + } + // Check for updates - provider, err := s.ProviderManager.GetProvider(file.Name()) + provider, err := s.ProviderManager.GetProvider(entry.Name()) if err != nil { log.Error(err) continue @@ -95,7 +122,7 @@ func (s *Server) getPluginPath(dir string) (string, error) { } for _, file := range files { - if !file.IsDir() { + if !file.IsDir() && file.Name() != manager.INITIAL_SETUP_LOCK_FILE_NAME { return filepath.Join(dir, file.Name()), nil } }