diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 08bd67a5a..649a63137 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -51,6 +51,7 @@ jobs: run: | go vet ./... go test -p 1 ./... -v + go test -p 1 ./ee -v --tags ee env: DATABASE: sqlite CLIENT_MODE: "off" diff --git a/config/config.go b/config/config.go index fa47fab66..934837a38 100644 --- a/config/config.go +++ b/config/config.go @@ -86,6 +86,7 @@ type ServerConfig struct { NetworksLimit int `yaml:"network_limit"` HostsLimit int `yaml:"host_limit"` DeployedByOperator bool `yaml:"deployed_by_operator"` + Environment string `yaml:"environment"` } // SQLConfig - Generic SQL Config diff --git a/ee/license.go b/ee/license.go index dfbb36eca..5ed507120 100644 --- a/ee/license.go +++ b/ee/license.go @@ -186,7 +186,7 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, erro return nil, err } - req, err := http.NewRequest(http.MethodPost, api_endpoint, bytes.NewReader(requestBody)) + req, err := http.NewRequest(http.MethodPost, getAccountsHost()+"/api/v1/license/validate", bytes.NewReader(requestBody)) if err != nil { return nil, err } @@ -217,6 +217,17 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, erro return body, err } +func getAccountsHost() string { + switch servercfg.GetEnvironment() { + case "dev": + return accountsHostDevelopment + case "staging": + return accountsHostStaging + default: + return accountsHostProduction + } +} + func cacheResponse(response []byte) error { var lrc = licenseResponseCache{ Body: response, diff --git a/ee/license_test.go b/ee/license_test.go new file mode 100644 index 000000000..30984680d --- /dev/null +++ b/ee/license_test.go @@ -0,0 +1,77 @@ +//go:build ee +// +build ee + +package ee + +import ( + "github.com/gravitl/netmaker/config" + "testing" +) + +func Test_getAccountsHost(t *testing.T) { + tests := []struct { + name string + envK string + envV string + conf string + want string + }{ + { + name: "no env var and no conf", + envK: "NOT_THE_CORRECT_ENV_VAR", + envV: "dev", + want: "https://api.accounts.netmaker.io", + }, + { + name: "dev env var", + envK: "ENVIRONMENT", + envV: "dev", + want: "https://api.dev.accounts.netmaker.io", + }, + { + name: "staging env var", + envK: "ENVIRONMENT", + envV: "staging", + want: "https://api.staging.accounts.netmaker.io", + }, + { + name: "prod env var", + envK: "ENVIRONMENT", + envV: "prod", + want: "https://api.accounts.netmaker.io", + }, + { + name: "dev conf", + conf: "dev", + want: "https://api.dev.accounts.netmaker.io", + }, + { + name: "staging conf", + conf: "staging", + want: "https://api.staging.accounts.netmaker.io", + }, + { + name: "prod conf", + conf: "prod", + want: "https://api.accounts.netmaker.io", + }, + { + name: "env var vs conf precedence", + envK: "ENVIRONMENT", + envV: "prod", + conf: "staging", + want: "https://api.accounts.netmaker.io", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.Config.Server.Environment = tt.conf + if tt.envK != "" { + t.Setenv(tt.envK, tt.envV) + } + if got := getAccountsHost(); got != tt.want { + t.Errorf("getAccountsHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ee/types.go b/ee/types.go index 17b5dd331..e7f77464e 100644 --- a/ee/types.go +++ b/ee/types.go @@ -1,9 +1,20 @@ package ee -import "fmt" +import ( + "fmt" +) + +// constants for accounts api hosts +const ( + // accountsHostDevelopment is the accounts api host for development environment + accountsHostDevelopment = "https://api.dev.accounts.netmaker.io" + // accountsHostStaging is the accounts api host for staging environment + accountsHostStaging = "https://api.staging.accounts.netmaker.io" + // accountsHostProduction is the accounts api host for production environment + accountsHostProduction = "https://api.accounts.netmaker.io" +) const ( - api_endpoint = "https://api.accounts.netmaker.io/api/v1/license/validate" license_cache_key = "license_response_cache" license_validation_err_msg = "invalid license" server_id_key = "nm-server-id" diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 18170a7eb..f32944619 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -782,6 +782,17 @@ func DeployedByOperator() bool { return config.Config.Server.DeployedByOperator } +// GetEnvironment returns the environment the server is running in (e.g. dev, staging, prod...) +func GetEnvironment() string { + if env := os.Getenv("ENVIRONMENT"); env != "" { + return env + } + if env := config.Config.Server.Environment; env != "" { + return env + } + return "" +} + // parseStunList - turn string into slice of StunServers func parseStunList(stunString string) ([]models.StunServer, error) { var err error