From e39975fdcf1affce087be15f4b6f2154f702d7bf Mon Sep 17 00:00:00 2001 From: Jeroen Vervaeke Date: Wed, 9 Oct 2024 12:55:06 +0100 Subject: [PATCH] Removed 404 check, added bearer token refresh logic when hitting 401 --- internal/cli/clusters/watch.go | 46 +++++++++++++++++++---------- internal/cli/clusters/watch_test.go | 3 +- internal/decryption/log_record.go | 2 +- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/internal/cli/clusters/watch.go b/internal/cli/clusters/watch.go index b72aee4767..5dba9ca2f8 100644 --- a/internal/cli/clusters/watch.go +++ b/internal/cli/clusters/watch.go @@ -16,7 +16,9 @@ package clusters import ( "context" + "errors" "fmt" + "net/http" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/cli/require" @@ -31,6 +33,7 @@ import ( type WatchOpts struct { cli.GlobalOpts cli.WatchOpts + cli.RefresherOpts name string store store.ClusterDescriber } @@ -45,24 +48,34 @@ func (opts *WatchOpts) initStore(ctx context.Context) func() error { } } -func isRetryable(err error) bool { - atlasErr, ok := atlasClustersPinned.AsError(err) - return ok && atlasErr.GetErrorCode() == "CLUSTER_NOT_FOUND" -} +func (opts *WatchOpts) watcher(ctx context.Context) func() (any, bool, error) { + return func() (any, bool, error) { + result, err := opts.store.AtlasCluster(opts.ConfigProjectID(), opts.name) + if err != nil { + var atlasClustersPinnedErr *atlasClustersPinned.GenericOpenAPIError -func (opts *WatchOpts) watcher() (any, bool, error) { - result, err := opts.store.AtlasCluster(opts.ConfigProjectID(), opts.name) - if err != nil { - return nil, false, err - } - if result.GetStateName() == "UPDATING" { - opts.IsRetryableErr = isRetryable + if errors.As(err, &atlasClustersPinnedErr) { + if *atlasClustersPinnedErr.Model().Error == http.StatusUnauthorized { + // Refresh the access token + // Note: this only updates the config, so we have to re-initialize the store + if err := opts.RefreshAccessToken(ctx); err != nil { + return nil, false, err + } + + // Re-initialize store, refreshAccessToken only refreshes the config + return nil, false, opts.initStore(ctx)() + } + } + } + if err != nil { + return nil, false, err + } + return nil, result.GetStateName() == "IDLE", nil } - return nil, result.GetStateName() == "IDLE", nil } -func (opts *WatchOpts) Run() error { - if _, err := opts.Watch(opts.watcher); err != nil { +func (opts *WatchOpts) Run(ctx context.Context) error { + if _, err := opts.Watch(opts.watcher(ctx)); err != nil { return err } @@ -93,11 +106,12 @@ You can interrupt the command's polling at any time with CTRL-C. opts.ValidateProjectID, opts.initStore(cmd.Context()), opts.InitOutput(cmd.OutOrStdout(), watchTemplate), + opts.InitFlow(config.Default()), ) }, - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { opts.name = args[0] - return opts.Run() + return opts.Run(cmd.Context()) }, } diff --git a/internal/cli/clusters/watch_test.go b/internal/cli/clusters/watch_test.go index f7eb8b1945..b7facad53b 100644 --- a/internal/cli/clusters/watch_test.go +++ b/internal/cli/clusters/watch_test.go @@ -17,6 +17,7 @@ package clusters import ( + "context" "testing" "github.com/golang/mock/gomock" @@ -44,7 +45,7 @@ func TestWatch_Run(t *testing.T) { Return(expected, nil). Times(1) - if err := opts.Run(); err != nil { + if err := opts.Run(context.Background()); err != nil { t.Fatalf("Run() unexpected error: %v", err) } } diff --git a/internal/decryption/log_record.go b/internal/decryption/log_record.go index a3e6262d2a..40a8979566 100644 --- a/internal/decryption/log_record.go +++ b/internal/decryption/log_record.go @@ -106,7 +106,7 @@ func (logLine *AuditLogLine) logAdditionalAuthData() []byte { const AADByteSize = 8 additionalAuthData := make([]byte, AADByteSize) - binary.LittleEndian.PutUint64(additionalAuthData, uint64(logLine.TS.UnixMilli())) //nolint:gosec + binary.LittleEndian.PutUint64(additionalAuthData, uint64(logLine.TS.UnixMilli())) return additionalAuthData }