From a7266c76fa0b60340c282d4691682eb4312afef0 Mon Sep 17 00:00:00 2001 From: gabrielseibel1 Date: Mon, 31 Jul 2023 16:21:39 -0300 Subject: [PATCH] Handle license validation failures with a middleware --- controllers/controller.go | 7 +++++ controllers/server.go | 13 ++++++--- ee/ee_controllers/middleware.go | 18 ++++++++++++ ee/initialize.go | 15 +++++++--- ee/license.go | 51 +++++++++++++++++++++------------ logic/timer.go | 12 +++++--- servercfg/serverconf.go | 6 ++-- 7 files changed, 89 insertions(+), 33 deletions(-) create mode 100644 ee/ee_controllers/middleware.go diff --git a/controllers/controller.go b/controllers/controller.go index 7abef20b3..8e54d38b1 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -14,6 +14,9 @@ import ( "github.com/gravitl/netmaker/servercfg" ) +// HttpMiddlewares - middleware functions for REST interactions +var HttpMiddlewares []mux.MiddlewareFunc + // HttpHandlers - handler functions for REST interactions var HttpHandlers = []interface{}{ nodeHandlers, @@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ",")) methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete}) + for _, middleware := range HttpMiddlewares { + r.Use(middleware) + } + for _, handler := range HttpHandlers { handler.(func(*mux.Router))(r) } diff --git a/controllers/server.go b/controllers/server.go index ba4a18f67..79c8e21f0 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -69,15 +69,20 @@ func getUsage(w http.ResponseWriter, r *http.Request) { // 200: serverConfigResponse func getStatus(w http.ResponseWriter, r *http.Request) { type status struct { - DB bool `json:"db_connected"` - Broker bool `json:"broker_connected"` - UnlicensedEE bool `json:"unlicensed_ee"` + DB bool `json:"db_connected"` + Broker bool `json:"broker_connected"` + LicenseError string `json:"license_error"` + } + + licenseErr := "" + if servercfg.ErrLicenseValidation != nil { + licenseErr = servercfg.ErrLicenseValidation.Error() } currentServerStatus := status{ DB: database.IsConnected(), Broker: mq.IsConnected(), - UnlicensedEE: servercfg.Is_EE && servercfg.IsUnlicensed, + LicenseError: licenseErr, } w.Header().Set("Content-Type", "application/json") diff --git a/ee/ee_controllers/middleware.go b/ee/ee_controllers/middleware.go new file mode 100644 index 000000000..35b79ec4c --- /dev/null +++ b/ee/ee_controllers/middleware.go @@ -0,0 +1,18 @@ +package ee_controllers + +import ( + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/servercfg" + "net/http" + "strings" +) + +func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if servercfg.ErrLicenseValidation != nil && !strings.HasPrefix(request.URL.Path, "/api/server") { + logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "unauthorized")) + return + } + handler.ServeHTTP(writer, request) + }) +} diff --git a/ee/initialize.go b/ee/initialize.go index 455ee59df..e779147bf 100644 --- a/ee/initialize.go +++ b/ee/initialize.go @@ -7,10 +7,10 @@ import ( controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/ee/ee_controllers" eelogic "github.com/gravitl/netmaker/ee/logic" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" + "golang.org/x/exp/slog" ) // InitEE - Initialize EE Logic @@ -18,6 +18,10 @@ func InitEE() { setIsEnterprise() servercfg.Is_EE = true models.SetLogo(retrieveEELogo()) + controller.HttpMiddlewares = append( + controller.HttpMiddlewares, + ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware, + ) controller.HttpHandlers = append( controller.HttpHandlers, ee_controllers.MetricHandlers, @@ -27,8 +31,11 @@ func InitEE() { ) logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { // == License Handling == - ValidateLicense() - logger.Log(0, "proceeding with Paid Tier license") + if err := ValidateLicense(); err != nil { + slog.Error(err.Error()) + return + } + slog.Info("proceeding with Paid Tier license") logic.SetFreeTierForTelemetry(false) // == End License Handling == AddLicenseHooks() @@ -48,7 +55,7 @@ func resetFailover() { for _, net := range nets { err = eelogic.ResetFailover(net.NetID) if err != nil { - logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error()) + slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error()) } } } diff --git a/ee/license.go b/ee/license.go index dfbb36eca..b10d5b1b4 100644 --- a/ee/license.go +++ b/ee/license.go @@ -12,7 +12,6 @@ import ( "golang.org/x/exp/slog" "io" "net/http" - "os" "time" "github.com/gravitl/netmaker/database" @@ -44,29 +43,40 @@ func AddLicenseHooks() { } } -// ValidateLicense - the initial license check for netmaker server +// ValidateLicense - the initial and periodic license check for netmaker server // checks if a license is valid + limits are not exceeded -// if license is free_tier and limits exceeds, then server should terminate -// if license is not valid, server should terminate -func ValidateLicense() error { +// if license is free_tier and limits exceeds, then function should error +// if license is not valid, function should error +func ValidateLicense() (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("%w: %s", errValidation, err.Error()) + servercfg.ErrLicenseValidation = err + } + }() + licenseKeyValue := servercfg.GetLicenseKey() netmakerTenantID := servercfg.GetNetmakerTenantID() slog.Info("proceeding with Netmaker license validation...") if len(licenseKeyValue) == 0 { - failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)")) + err = errors.New("empty license-key (LICENSE_KEY environment variable)") + return err } if len(netmakerTenantID) == 0 { - failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")) + err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)") + return err } apiPublicKey, err := getLicensePublicKey(licenseKeyValue) if err != nil { - failValidation(fmt.Errorf("failed to get license public key: %w", err)) + err = fmt.Errorf("failed to get license public key: %w", err) + return err } tempPubKey, tempPrivKey, err := FetchApiServerKeys() if err != nil { - failValidation(fmt.Errorf("failed to fetch api server keys: %w", err)) + err = fmt.Errorf("failed to fetch api server keys: %w", err) + return err } licenseSecret := LicenseSecret{ @@ -76,35 +86,42 @@ func ValidateLicense() error { secretData, err := json.Marshal(&licenseSecret) if err != nil { - failValidation(fmt.Errorf("failed to marshal license secret: %w", err)) + err = fmt.Errorf("failed to marshal license secret: %w", err) + return err } encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey) if err != nil { - failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err)) + err = fmt.Errorf("failed to encrypt license secret data: %w", err) + return err } validationResponse, err := validateLicenseKey(encryptedData, tempPubKey) if err != nil { - failValidation(fmt.Errorf("failed to validate license key: %w", err)) + err = fmt.Errorf("failed to validate license key: %w", err) + return err } if len(validationResponse) == 0 { - failValidation(errors.New("empty validation response")) + err = errors.New("empty validation response") + return err } var licenseResponse ValidatedLicense if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil { - failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err)) + err = fmt.Errorf("failed to unmarshal validation response: %w", err) + return err } respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey) if err != nil { - failValidation(fmt.Errorf("failed to decrypt license: %w", err)) + err = fmt.Errorf("failed to decrypt license: %w", err) + return err } license := LicenseKey{} if err = json.Unmarshal(respData, &license); err != nil { - failValidation(fmt.Errorf("failed to unmarshal license key: %w", err)) + err = fmt.Errorf("failed to unmarshal license key: %w", err) + return err } slog.Info("License validation succeeded!") @@ -159,8 +176,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { } func failValidation(err error) { - slog.Error(errValidation.Error(), "error", err) - os.Exit(0) } func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { diff --git a/logic/timer.go b/logic/timer.go index 89070f2fe..7a21a3785 100644 --- a/logic/timer.go +++ b/logic/timer.go @@ -3,10 +3,11 @@ package logic import ( "context" "fmt" + "github.com/gravitl/netmaker/logger" + "golang.org/x/exp/slog" "sync" "time" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" ) @@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) { for { select { case <-ctx.Done(): - logger.Log(0, "## Stopping Hook Manager") + slog.Error("## Stopping Hook Manager") return case newhook := <-HookManagerCh: wg.Add(1) @@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er case <-ctx.Done(): return case <-ticker.C: - hook() + if err := hook(); err != nil { + slog.Error(err.Error()) + } } } @@ -85,6 +88,7 @@ var timeHooks = []interface{}{ } func loggerDump() error { + // TODO use slog? logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) return nil } @@ -93,7 +97,7 @@ func loggerDump() error { func runHooks() { for _, hook := range timeHooks { if err := hook.(func() error)(); err != nil { - logger.Log(1, "error occurred when running timer function:", err.Error()) + slog.Error("error occurred when running timer function", "error", err.Error()) } } } diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 101ba94f5..f4e70030e 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -18,9 +18,9 @@ import ( const EmqxBrokerType = "emqx" var ( - Version = "dev" - Is_EE = false - IsUnlicensed = false + Version = "dev" + Is_EE = false + ErrLicenseValidation error ) // SetHost - sets the host ip