Skip to content

Commit

Permalink
Handle license validation failures with a middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielseibel1 committed Jul 31, 2023
1 parent 2bef48d commit a7266c7
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 33 deletions.
7 changes: 7 additions & 0 deletions controllers/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/gravitl/netmaker/servercfg"
)

// HttpMiddlewares - middleware functions for REST interactions
var HttpMiddlewares []mux.MiddlewareFunc

// HttpHandlers - handler functions for REST interactions
var HttpHandlers = []interface{}{
nodeHandlers,
Expand Down Expand Up @@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})

for _, middleware := range HttpMiddlewares {
r.Use(middleware)
}

for _, handler := range HttpHandlers {
handler.(func(*mux.Router))(r)
}
Expand Down
13 changes: 9 additions & 4 deletions controllers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,20 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
// 200: serverConfigResponse
func getStatus(w http.ResponseWriter, r *http.Request) {
type status struct {
DB bool `json:"db_connected"`
Broker bool `json:"broker_connected"`
UnlicensedEE bool `json:"unlicensed_ee"`
DB bool `json:"db_connected"`
Broker bool `json:"broker_connected"`
LicenseError string `json:"license_error"`
}

licenseErr := ""
if servercfg.ErrLicenseValidation != nil {
licenseErr = servercfg.ErrLicenseValidation.Error()
}

currentServerStatus := status{
DB: database.IsConnected(),
Broker: mq.IsConnected(),
UnlicensedEE: servercfg.Is_EE && servercfg.IsUnlicensed,
LicenseError: licenseErr,
}

w.Header().Set("Content-Type", "application/json")
Expand Down
18 changes: 18 additions & 0 deletions ee/ee_controllers/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package ee_controllers

import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg"
"net/http"
"strings"
)

func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if servercfg.ErrLicenseValidation != nil && !strings.HasPrefix(request.URL.Path, "/api/server") {
logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "unauthorized"))
return
}
handler.ServeHTTP(writer, request)
})
}
15 changes: 11 additions & 4 deletions ee/initialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@ import (
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/ee/ee_controllers"
eelogic "github.com/gravitl/netmaker/ee/logic"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)

// InitEE - Initialize EE Logic
func InitEE() {
setIsEnterprise()
servercfg.Is_EE = true
models.SetLogo(retrieveEELogo())
controller.HttpMiddlewares = append(
controller.HttpMiddlewares,
ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
)
controller.HttpHandlers = append(
controller.HttpHandlers,
ee_controllers.MetricHandlers,
Expand All @@ -27,8 +31,11 @@ func InitEE() {
)
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
// == License Handling ==
ValidateLicense()
logger.Log(0, "proceeding with Paid Tier license")
if err := ValidateLicense(); err != nil {
slog.Error(err.Error())
return
}
slog.Info("proceeding with Paid Tier license")
logic.SetFreeTierForTelemetry(false)
// == End License Handling ==
AddLicenseHooks()
Expand All @@ -48,7 +55,7 @@ func resetFailover() {
for _, net := range nets {
err = eelogic.ResetFailover(net.NetID)
if err != nil {
logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
}
}
}
Expand Down
51 changes: 33 additions & 18 deletions ee/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"golang.org/x/exp/slog"
"io"
"net/http"
"os"
"time"

"github.com/gravitl/netmaker/database"
Expand Down Expand Up @@ -44,29 +43,40 @@ func AddLicenseHooks() {
}
}

// ValidateLicense - the initial license check for netmaker server
// ValidateLicense - the initial and periodic license check for netmaker server
// checks if a license is valid + limits are not exceeded
// if license is free_tier and limits exceeds, then server should terminate
// if license is not valid, server should terminate
func ValidateLicense() error {
// if license is free_tier and limits exceeds, then function should error
// if license is not valid, function should error
func ValidateLicense() (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("%w: %s", errValidation, err.Error())
servercfg.ErrLicenseValidation = err
}
}()

licenseKeyValue := servercfg.GetLicenseKey()
netmakerTenantID := servercfg.GetNetmakerTenantID()
slog.Info("proceeding with Netmaker license validation...")
if len(licenseKeyValue) == 0 {
failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
err = errors.New("empty license-key (LICENSE_KEY environment variable)")
return err
}
if len(netmakerTenantID) == 0 {
failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
return err
}

apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
if err != nil {
failValidation(fmt.Errorf("failed to get license public key: %w", err))
err = fmt.Errorf("failed to get license public key: %w", err)
return err
}

tempPubKey, tempPrivKey, err := FetchApiServerKeys()
if err != nil {
failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
err = fmt.Errorf("failed to fetch api server keys: %w", err)
return err
}

licenseSecret := LicenseSecret{
Expand All @@ -76,35 +86,42 @@ func ValidateLicense() error {

secretData, err := json.Marshal(&licenseSecret)
if err != nil {
failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
err = fmt.Errorf("failed to marshal license secret: %w", err)
return err
}

encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
if err != nil {
failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
err = fmt.Errorf("failed to encrypt license secret data: %w", err)
return err
}

validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
if err != nil {
failValidation(fmt.Errorf("failed to validate license key: %w", err))
err = fmt.Errorf("failed to validate license key: %w", err)
return err
}
if len(validationResponse) == 0 {
failValidation(errors.New("empty validation response"))
err = errors.New("empty validation response")
return err
}

var licenseResponse ValidatedLicense
if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
err = fmt.Errorf("failed to unmarshal validation response: %w", err)
return err
}

respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
if err != nil {
failValidation(fmt.Errorf("failed to decrypt license: %w", err))
err = fmt.Errorf("failed to decrypt license: %w", err)
return err
}

license := LicenseKey{}
if err = json.Unmarshal(respData, &license); err != nil {
failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
err = fmt.Errorf("failed to unmarshal license key: %w", err)
return err
}

slog.Info("License validation succeeded!")
Expand Down Expand Up @@ -159,8 +176,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
}

func failValidation(err error) {
slog.Error(errValidation.Error(), "error", err)
os.Exit(0)
}

func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
Expand Down
12 changes: 8 additions & 4 deletions logic/timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package logic
import (
"context"
"fmt"
"github.com/gravitl/netmaker/logger"
"golang.org/x/exp/slog"
"sync"
"time"

"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
)

Expand Down Expand Up @@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
for {
select {
case <-ctx.Done():
logger.Log(0, "## Stopping Hook Manager")
slog.Error("## Stopping Hook Manager")
return
case newhook := <-HookManagerCh:
wg.Add(1)
Expand All @@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
case <-ctx.Done():
return
case <-ticker.C:
hook()
if err := hook(); err != nil {
slog.Error(err.Error())
}
}
}

Expand All @@ -85,6 +88,7 @@ var timeHooks = []interface{}{
}

func loggerDump() error {
// TODO use slog?
logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
return nil
}
Expand All @@ -93,7 +97,7 @@ func loggerDump() error {
func runHooks() {
for _, hook := range timeHooks {
if err := hook.(func() error)(); err != nil {
logger.Log(1, "error occurred when running timer function:", err.Error())
slog.Error("error occurred when running timer function", "error", err.Error())
}
}
}
6 changes: 3 additions & 3 deletions servercfg/serverconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
const EmqxBrokerType = "emqx"

var (
Version = "dev"
Is_EE = false
IsUnlicensed = false
Version = "dev"
Is_EE = false
ErrLicenseValidation error
)

// SetHost - sets the host ip
Expand Down

0 comments on commit a7266c7

Please sign in to comment.