diff --git a/globals/globalVars.go b/globals/globalVars.go index f43d175..ad8d529 100644 --- a/globals/globalVars.go +++ b/globals/globalVars.go @@ -3,16 +3,16 @@ package globals import discovery "github.com/gkarthiks/k8s-discovery" var ( - VaultIPList map[string]struct{} - K8s *discovery.K8s - Namespace string - HttpTimeout string + VaultIPList map[string]string + K8s *discovery.K8s + Namespace string + HttpTimeout string LabelSelector string ) const ( - HealthCheckPath = ":8200/v1/sys/seal-status" - ProxyPath = ":8200" - DefaultTimeOut = "1" - DefaultBalancerPort = 8000 + HealthCheckPath = ":8200/v1/sys/seal-status" + ProxyPath = ":8200" + DefaultTimeOut = "1" + DefaultBalancerPort = 8000 ) diff --git a/helper/util.go b/helper/util.go index df87bf9..31e619b 100644 --- a/helper/util.go +++ b/helper/util.go @@ -18,7 +18,7 @@ const ( ) // GetVaultIPsFromLabelSelectors will extract the IP Addresses for the pods that matches the labelSelectors -func GetVaultIPsFromLabelSelectors() { +func GetVaultIPsFromLabelSelectors(vaultPool *types.VaultPool) { if len(globals.LabelSelector) > 0 { log.Infof("Discovering the Vault pods based on the label selector '%v'.", globals.LabelSelector) strings.Split(globals.LabelSelector, ",") @@ -29,7 +29,7 @@ func GetVaultIPsFromLabelSelectors() { if err != nil { log.Fatalf("err while retrieving the pods: %v", err) } else { - fetchIpAddress(pods) + populateIpAddresses(pods, vaultPool) } log.Infof("Finalized pods discovery process with label selector. Obtained the IP Address %v", reflect.ValueOf(globals.VaultIPList).MapKeys()) } @@ -59,14 +59,25 @@ func HealthCheck(vaultPool *types.VaultPool) { } // extracts the pods IP from the selected pods -func fetchIpAddress(podsList *v1.PodList) { +func populateIpAddresses(podsList *v1.PodList, vaultPool *types.VaultPool) { + currentPodNames := make(map[string]struct{}) for _, pod := range podsList.Items { + currentPodNames[pod.Name] = struct{}{} if pod.Status.Phase == v1.PodRunning { - if _, ok := globals.VaultIPList[pod.Status.PodIP]; ok { - log.Infof("%v already added", pod.Status.PodIP) - } else { - globals.VaultIPList[pod.Status.PodIP] = struct{}{} - } + // adding the currently discovered pod ips + globals.VaultIPList[pod.Name] = pod.Status.PodIP + //if _, ok := globals.VaultIPList[pod.Status.PodIP]; ok { + // log.Infof("%v already added", pod.Status.PodIP) + //} else { + // globals.VaultIPList[pod.Status.PodIP] = struct{}{} + //} + } + } + for historyPodName, ipAddress := range globals.VaultIPList { + if _, ok := currentPodNames[historyPodName]; !ok { + // removing the obsolete pod and its details + delete(globals.VaultIPList, historyPodName) + vaultPool.RetireBackend(ipAddress) } } } diff --git a/main.go b/main.go index 272ed4f..7984dbc 100644 --- a/main.go +++ b/main.go @@ -21,11 +21,11 @@ import ( func init() { log.SetFormatter(&log.JSONFormatter{ FieldMap: log.FieldMap{ - "FieldKeyTime": "@timestamp", - "version": "@BuildVersion", - }, - CallerPrettyfier: nil, - PrettyPrint: false, + "FieldKeyTime": "@timestamp", + "version": "@BuildVersion", + }, + CallerPrettyfier: nil, + PrettyPrint: false, }) log.Infof("Vault Balancer running version: `%v`", BuildVersion) @@ -38,7 +38,7 @@ func init() { if !avail { log.Fatalf("No label selector has been provided. Please provide the label selector in `VAULT_LABEL_SELECTOR` key.") } else { - globals.VaultIPList = make(map[string]struct{}) + globals.VaultIPList = make(map[string]string) globals.LabelSelector = label } @@ -47,7 +47,7 @@ func init() { log.Warnf("Balancer port is not specified. Please provide the balancer port in `BALANCER_PORT` key. Now the default will be used. BALANCER_PORT: %v", globals.DefaultBalancerPort) balancerPort = globals.DefaultBalancerPort } else { - balancerPort,_ = strconv.Atoi(balancerPortStr) + balancerPort, _ = strconv.Atoi(balancerPortStr) } globals.HttpTimeout, avail = os.LookupEnv("HTTP_TIMEOUT") @@ -60,7 +60,7 @@ func init() { var ( BuildVersion = "dev" balancerPort int - vaultPool types.VaultPool + vaultPool types.VaultPool ) func main() { @@ -68,8 +68,8 @@ func main() { // start the balancer http service server := http.Server{ - Addr: fmt.Sprintf(":%d", balancerPort), - Handler: http.HandlerFunc(loadBalance), + Addr: fmt.Sprintf(":%d", balancerPort), + Handler: http.HandlerFunc(loadBalance), } // log.Infof("Vault Balancer started and running at :%d", balancerPort) @@ -86,7 +86,7 @@ func startRoutine() { for { select { case <-t.C: - helper.GetVaultIPsFromLabelSelectors() + helper.GetVaultIPsFromLabelSelectors(&vaultPool) setUpProxies(&vaultPool) helper.HealthCheck(&vaultPool) } @@ -113,13 +113,13 @@ func loadBalance(w http.ResponseWriter, r *http.Request) { // setUpProxies will create the reverse proxies for the identified IPs func setUpProxies(vaultPool *types.VaultPool) { log.Infof("Setting up the reverse proxy for %v", reflect.ValueOf(globals.VaultIPList).MapKeys()) - for individualIP, _ := range globals.VaultIPList { + for _, individualIP := range globals.VaultIPList { sanitizedIP := strings.TrimSpace(individualIP) - vaultUrl, err := url.Parse("http://"+sanitizedIP + globals.ProxyPath) + vaultUrl, err := url.Parse("http://" + sanitizedIP + globals.ProxyPath) if err != nil { log.Errorf("error occurred while converting string to URL for proxy path. error: %v", err) } - healthUrl, _ := url.Parse("http://"+sanitizedIP + globals.HealthCheckPath) + healthUrl, _ := url.Parse("http://" + sanitizedIP + globals.HealthCheckPath) proxy := httputil.NewSingleHostReverseProxy(vaultUrl) proxy.ErrorHandler = func(writer http.ResponseWriter, request *http.Request, e error) { diff --git a/types/vault_pool_type.go b/types/vault_pool_type.go index 1d8d34d..c8bda56 100644 --- a/types/vault_pool_type.go +++ b/types/vault_pool_type.go @@ -22,6 +22,15 @@ func (vp *VaultPool) AddBackend(vaultBackend *VaultBackend) { vp.vaultBackends = append(vp.vaultBackends, vaultBackend) } +// AddBackend to the existing vault pool +func (vp *VaultPool) RetireBackend(obsoleteIP string) { + for index, currBackend := range vp.vaultBackends { + if currBackend.IP == obsoleteIP { + vp.vaultBackends = append(vp.vaultBackends[:index], vp.vaultBackends[index+1:]...) + } + } +} + // NextIndex atomically increase the counter and return an index func (vp *VaultPool) NextIndex() int { return int(atomic.AddUint64(&vp.current, uint64(1)) % uint64(len(vp.vaultBackends)))