diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go index 9a2de3214fbb..b7d8f63f9d10 100644 --- a/pkg/agent/loadbalancer/config.go +++ b/pkg/agent/loadbalancer/config.go @@ -15,8 +15,8 @@ type lbConfig struct { func (lb *LoadBalancer) writeConfig() error { config := &lbConfig{ - ServerURL: lb.serverURL, - ServerAddresses: lb.serverAddresses, + ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(), + ServerAddresses: lb.servers.getAddresses(), } configOut, err := json.MarshalIndent(config, "", " ") if err != nil { @@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error { } func (lb *LoadBalancer) updateConfig() error { - writeConfig := true if configBytes, err := os.ReadFile(lb.configFile); err == nil { config := &lbConfig{} if err := json.Unmarshal(configBytes, config); err == nil { - if config.ServerURL == lb.serverURL { - writeConfig = false - lb.setServers(config.ServerAddresses) + // if the default server from the config matches our current default, + // load the rest of the addresses as well. + if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() { + lb.Update(config.ServerAddresses) + return nil } } } - if writeConfig { - if err := lb.writeConfig(); err != nil { - return err - } - } - return nil + // config didn't exist or used a different default server, write the current config to disk. + return lb.writeConfig() } diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index db9fa6f16f72..2f6d33fbf4c2 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -2,55 +2,29 @@ package loadbalancer import ( "context" - "errors" "fmt" "net" + "net/url" "os" "path/filepath" - "sync" - "time" + "strings" "github.com/inetaf/tcpproxy" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" ) -// server tracks the connections to a server, so that they can be closed when the server is removed. -type server struct { - // This mutex protects access to the connections map. All direct access to the map should be protected by it. - mutex sync.Mutex - address string - healthCheck func() bool - connections map[net.Conn]struct{} -} - -// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. -type serverConn struct { - server *server - net.Conn -} - // LoadBalancer holds data for a local listener which forwards connections to a // pool of remote servers. It is not a proper load-balancer in that it does not // actually balance connections, but instead fails over to a new server only // when a connection attempt to the currently selected server fails. type LoadBalancer struct { - // This mutex protects access to servers map and randomServers list. - // All direct access to the servers map/list should be protected by it. - mutex sync.RWMutex - proxy *tcpproxy.Proxy - - serviceName string - configFile string - localAddress string - localServerURL string - defaultServerAddress string - serverURL string - serverAddresses []string - randomServers []string - servers map[string]*server - currentServerAddress string - nextServerIndex int + serviceName string + configFile string + scheme string + localAddress string + servers serverList + proxy *tcpproxy.Proxy } const RandomPort = 0 @@ -63,7 +37,7 @@ var ( // New contstructs a new LoadBalancer instance. The default server URL, and // currently active servers, are stored in a file within the dataDir. -func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { +func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { config := net.ListenConfig{Control: reusePort} var localAddress string if isIPv6 { @@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo return nil, err } - // if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with. - localAddress = listener.Addr().String() - - defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress) + serverURL, err := url.Parse(defaultServerURL) if err != nil { return nil, err } - if serverURL == localServerURL { - logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) - defaultServerAddress = "" + // Set explicit port from scheme + if serverURL.Port() == "" { + if strings.ToLower(serverURL.Scheme) == "http" { + serverURL.Host += ":80" + } + if strings.ToLower(serverURL.Scheme) == "https" { + serverURL.Host += ":443" + } } lb := &LoadBalancer{ - serviceName: serviceName, - configFile: filepath.Join(dataDir, "etc", serviceName+".json"), - localAddress: localAddress, - localServerURL: localServerURL, - defaultServerAddress: defaultServerAddress, - servers: make(map[string]*server), - serverURL: serverURL, + serviceName: serviceName, + configFile: filepath.Join(dataDir, "etc", serviceName+".json"), + scheme: serverURL.Scheme, + localAddress: listener.Addr().String(), } - lb.setServers([]string{lb.defaultServerAddress}) + // if starting pointing at ourselves, don't set a default server address, + // which will cause all dials to fail until servers are added. + if serverURL.Host == lb.localAddress { + logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) + } else { + lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host) + } lb.proxy = &tcpproxy.Proxy{ ListenFunc: func(string, string) (net.Listener, error) { @@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo } lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ Addr: serviceName, - DialContext: lb.dialContext, + DialContext: lb.servers.dialContext, OnDialError: onDialError, }) @@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo if err := lb.proxy.Start(); err != nil { return nil, err } - logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress) + logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) - go lb.runHealthChecks(ctx) + go lb.servers.runHealthChecks(ctx, lb.serviceName) return lb, nil } +// Update updates the list of server addresses to contain only the listed servers. func (lb *LoadBalancer) Update(serverAddresses []string) { - if lb == nil { - return - } - if !lb.setServers(serverAddresses) { + if !lb.servers.setAddresses(lb.serviceName, serverAddresses) { return } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress) + + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } } -func (lb *LoadBalancer) LoadBalancerServerURL() string { - if lb == nil { - return "" +// SetDefault sets the selected address as the default / fallback address +func (lb *LoadBalancer) SetDefault(serverAddress string) { + lb.servers.setDefaultAddress(lb.serviceName, serverAddress) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } - return lb.localServerURL } -func (lb *LoadBalancer) ServerAddresses() []string { - if lb == nil { - return nil +// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. +func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) { + if err := lb.servers.setHealthCheck(address, healthCheck); err != nil { + logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err) + } else { + logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address) } - return lb.serverAddresses } -func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - - var allChecksFailed bool - startIndex := lb.nextServerIndex - for { - targetServer := lb.currentServerAddress - - server := lb.servers[targetServer] - if server == nil || targetServer == "" { - logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer) - } else if allChecksFailed || server.healthCheck() { - dialTime := time.Now() - conn, err := server.dialContext(ctx, network, targetServer) - if err == nil { - return conn, nil - } - logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err) - // Don't close connections to the failed server if we're retrying with health checks ignored. - // We don't want to disrupt active connections if it is unlikely they will have anywhere to go. - if !allChecksFailed { - defer server.closeAll() - } - } else { - logrus.Debugf("Dial health check failed for %s", targetServer) - } - - newServer, err := lb.nextServer(targetServer) - if err != nil { - return nil, err - } - if targetServer != newServer { - logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer) - } - if ctx.Err() != nil { - return nil, ctx.Err() - } +func (lb *LoadBalancer) LocalURL() string { + return lb.scheme + "://" + lb.localAddress +} - maxIndex := len(lb.randomServers) - if startIndex > maxIndex { - startIndex = maxIndex - } - if lb.nextServerIndex == startIndex { - if allChecksFailed { - return nil, errors.New("all servers failed") - } - logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName) - allChecksFailed = true - } - } +func (lb *LoadBalancer) ServerAddresses() []string { + return lb.servers.getAddresses() } func onDialError(src net.Conn, dstDialErr error) { @@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) { } // ResetLoadBalancer will delete the local state file for the load balancer on disk -func ResetLoadBalancer(dataDir, serviceName string) error { +func ResetLoadBalancer(dataDir, serviceName string) { stateFile := filepath.Join(dataDir, "etc", serviceName+".json") - if err := os.Remove(stateFile); err != nil { + if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) { logrus.Warn(err) } - return nil } diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index cbfdf982c690..260e180397dd 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "slices" "strings" "testing" "time" @@ -111,15 +112,19 @@ func Test_UnitFailOver(t *testing.T) { t.Fatalf("New() failed: %v", err) } - parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) + parsedURL, err := url.Parse(lb.LocalURL()) if err != nil { t.Fatalf("url.Parse failed: %v", err) } localAddress := parsedURL.Host + t.Logf("Adding node1 server: %v", lb.servers.getServers()) + // add the node as a new server address. lb.Update([]string{node1Server.address()}) + t.Logf("Added node1 server: %v", lb.servers.getServers()) + // make sure connections go to the node conn1, err := net.Dial("tcp", localAddress) if err != nil { @@ -134,7 +139,7 @@ func Test_UnitFailOver(t *testing.T) { t.Log("conn1 tested OK") // set failing health check for node 1 - lb.SetHealthCheck(node1Server.address(), func() bool { return false }) + lb.SetHealthCheck(node1Server.address(), func() HealthCheckResult { return HealthCheckResultFailed }) // Server connections are checked every second, now that node 1 is failed // the connections to it should be closed. @@ -146,9 +151,7 @@ func Test_UnitFailOver(t *testing.T) { t.Log("conn1 closed on failure OK") - // make sure connection still goes to the first node - it is failing health checks but so - // is the default endpoint, so it should be tried first with health checks disabled, - // before failing back to the default. + // connections shoould go to the default now that node 1 is failed conn2, err := net.Dial("tcp", localAddress) if err != nil { t.Fatalf("net.Dial failed: %v", err) @@ -156,7 +159,7 @@ func Test_UnitFailOver(t *testing.T) { } if result, err := ping(conn2); err != nil { t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { + } else if result != "default:ping" { t.Fatalf("Unexpected ping(conn2) result: %v", result) } @@ -168,7 +171,7 @@ func Test_UnitFailOver(t *testing.T) { if result, err := ping(conn2); err != nil { t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { + } else if result != "default:ping" { t.Fatalf("Unexpected ping(conn2) result: %v", result) } @@ -191,15 +194,13 @@ func Test_UnitFailOver(t *testing.T) { t.Log("conn3 tested OK") - if _, err := ping(conn2); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn2") - } - - t.Log("conn2 closed on failure OK") + t.Logf("Adding node2 server: %v", lb.servers.getServers()) // add the second node as a new server address. lb.Update([]string{node2Server.address()}) + t.Logf("Added node2 server: %v", lb.servers.getServers()) + // make sure connection now goes to the second node, // and connections to the default are closed. conn4, err := net.Dial("tcp", localAddress) @@ -219,11 +220,63 @@ func Test_UnitFailOver(t *testing.T) { // server, connections to the default server should be closed time.Sleep(2 * time.Second) + if _, err := ping(conn2); err == nil { + t.Fatal("Unexpected successful ping on closed connection conn2") + } + + t.Log("conn2 closed on failure OK") + if _, err := ping(conn3); err == nil { t.Fatal("Unexpected successful ping on connection conn3") } t.Log("conn3 closed on failure OK") + + t.Logf("Adding default server: %v", lb.servers.getServers()) + + // add the default as a full server + lb.Update([]string{node2Server.address(), defaultServer.address()}) + + // confirm that both servers are listed in the address list + serverAddresses := lb.ServerAddresses() + if len(serverAddresses) != 2 { + t.Fatalf("Unexpected server address count") + } + + if !slices.Contains(serverAddresses, node2Server.address()) { + t.Fatalf("node2 server not in server address list") + } + + if !slices.Contains(serverAddresses, defaultServer.address()) { + t.Fatalf("default server not in server address list") + } + + // confirm that the default is still listed as default + if lb.servers.getDefaultAddress() != defaultServer.address() { + t.Fatalf("default server is not default") + } + + t.Logf("Default server added OK: %v", lb.servers.getServers()) + + // remove the default as a server + lb.Update([]string{node2Server.address()}) + + // confirm that it is not listed as a server + serverAddresses = lb.ServerAddresses() + if len(serverAddresses) != 1 { + t.Fatalf("Unexpected server address count") + } + + if slices.Contains(serverAddresses, defaultServer.address()) { + t.Fatalf("default server in server address list") + } + + // but is still listed as the default + if lb.servers.getDefaultAddress() != defaultServer.address() { + t.Fatalf("default server is not default") + } + + t.Logf("Default removed added OK: %v", lb.servers.getServers()) } // Test_UnitFailFast confirms that connnections to invalid addresses fail quickly @@ -277,7 +330,7 @@ func Test_UnitFailUnreachable(t *testing.T) { } // Set failing health check to reduce retries - lb.SetHealthCheck(serverAddr, func() bool { return false }) + lb.SetHealthCheck(serverAddr, func() HealthCheckResult { return HealthCheckResultFailed }) conn, err := net.Dial("tcp", lb.localAddress) if err != nil { diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 675bee5c5c86..919d85de7638 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -1,118 +1,415 @@ package loadbalancer import ( + "cmp" "context" - "math/rand" + "errors" + "fmt" "net" "slices" + "sync" "time" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" ) -func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) - if len(serverAddresses) == 0 { - return false - } +type HealthCheckFunc func() HealthCheckResult + +// HealthCheckResult indicates the status of a server health check poll. +// For health-checks that poll in the background, Unknown should be returned +// if a poll has not occurred since the last check. +type HealthCheckResult int - lb.mutex.Lock() - defer lb.mutex.Unlock() +const ( + HealthCheckResultUnknown HealthCheckResult = iota + HealthCheckResultFailed + HealthCheckResultOK +) + +// serverList tracks potential backend servers for use by a loadbalancer. +type serverList struct { + // This mutex protects access to the server list. All direct access to the list should be protected by it. + mutex sync.Mutex + servers []*server +} - newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.serverAddresses...) +// setServers updates the server list to contain only the selected addresses. +func (sl *serverList) setAddresses(serviceName string, addresses []string) bool { + newAddresses := sets.New(addresses...) + curAddresses := sets.New(sl.getAddresses()...) if newAddresses.Equal(curAddresses) { return false } - for addedServer := range newAddresses.Difference(curAddresses) { - logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) - lb.servers[addedServer] = &server{ - address: addedServer, - connections: make(map[net.Conn]struct{}), - healthCheck: func() bool { return true }, + sl.mutex.Lock() + defer sl.mutex.Unlock() + + var closeAllFuncs []func() + var defaultServer *server + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + defaultServer = sl.servers[i] + } + + // add new servers + for addedAddress := range newAddresses.Difference(curAddresses) { + if defaultServer != nil && defaultServer.address == addedAddress { + logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, statePreferred, serviceName) + defaultServer.state = statePreferred + defaultServer.lastTransition = time.Now() + } else { + s := newServer(addedAddress, false) + logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address) + sl.servers = append(sl.servers, s) + } + } + + // remove old servers + for removedAddress := range curAddresses.Difference(newAddresses) { + if defaultServer != nil && defaultServer.address == removedAddress { + // demote the default server down to standby, instead of deleting it + defaultServer.state = stateStandby + closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll) + } else { + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.address == removedAddress { + logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address) + // set state to invalid to prevent server from making additional connections + s.state = stateInvalid + closeAllFuncs = append(closeAllFuncs, s.closeAll) + return true + } + return false + }) + } + } + + slices.SortFunc(sl.servers, compareServers) + + // Close all connections to servers that were removed + for _, closeAll := range closeAllFuncs { + closeAll() + } + + return true +} + +// getAddresses returns the addresses of all servers. +// If the default server is in standby state, indicating it is only present +// because it is the default, it is not returned in this list. +func (sl *serverList) getAddresses() []string { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + addresses := make([]string, 0, len(sl.servers)) + for _, s := range sl.servers { + if s.isDefault && s.state == stateStandby { + continue + } + addresses = append(addresses, s.address) + } + return addresses +} + +// setDefault sets the server with the provided address as the default server. +// The default flag is cleared on all other servers, and if the server was previously +// only kept in the list because it was the default, it is removed. +func (sl *serverList) setDefaultAddress(serviceName, address string) { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // deal with existing default first + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + s := sl.servers[i] + s.isDefault = false + if s.state == stateStandby { + s.state = stateInvalid + defer s.closeAll() + sl.servers = slices.Delete(sl.servers, i, i) } } - for removedServer := range curAddresses.Difference(newAddresses) { - server := lb.servers[removedServer] - if server != nil { - logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) - // Defer closing connections until after the new server list has been put into place. - // Closing open connections ensures that anything stuck retrying on a stale server is forced - // over to a valid endpoint. - defer server.closeAll() - // Don't delete the default server from the server map, in case we need to fall back to it. - if removedServer != lb.defaultServerAddress { - delete(lb.servers, removedServer) + // update or create server with selected address + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + sl.servers[i].isDefault = false + } else { + sl.servers = append(sl.servers, newServer(address, true)) + } + + logrus.Infof("Updated load balancer %s default server: %s", serviceName, address) + slices.SortFunc(sl.servers, compareServers) +} + +// getDefault returns the address of the default server. +func (sl *serverList) getDefaultAddress() string { + if s := sl.getDefaultServer(); s != nil { + return s.address + } + return "" +} + +// getDefault returns the default server. +func (sl *serverList) getDefaultServer() *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + return sl.servers[i] + } + return nil +} + +// getServers returns a copy of the servers list that can be safely iterated over without holding a lock +func (sl *serverList) getServers() []*server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + return slices.Clone(sl.servers) +} + +// getServer returns the first server with the specified address +func (sl *serverList) getServer(address string) *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + return sl.servers[i] + } + return nil +} + +// setHealthCheck updates the health check function for a server, replacing the +// current function. +func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error { + if s := sl.getServer(address); s != nil { + s.healthCheck = healthCheck + return nil + } + return fmt.Errorf("no server found for %s", address) +} + +// recordSuccess records a successful check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordSuccess(srv *server, r reason) { + var new_state state + switch srv.state { + case stateFailed: + // dialed or health checked OK once, improve to recovering + new_state = stateRecovering + case stateRecovering: + if r == reasonHealthCheck { + // was recovering due to successful dial or first health check, can now improve + if len(srv.connections) > 0 { + // server accepted connections while recovering, attempt to go straight to active + new_state = stateActive + } else { + // no connections, just make it preferred + new_state = statePreferred + } + } + case stateHealthy: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } + case statePreferred: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } else { + if time.Now().Sub(srv.lastTransition) > time.Minute { + // has been preferred for a while without being dialed, demote to healthy + new_state = stateHealthy } } } - lb.serverAddresses = serverAddresses - lb.randomServers = append([]string{}, lb.serverAddresses...) - rand.Shuffle(len(lb.randomServers), func(i, j int) { - lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] - }) - // If the current server list does not contain the default server address, - // we want to include it in the random server list so that it can be tried if necessary. - // However, it should be treated as always failing health checks so that it is only - // used if all other endpoints are unavailable. - if !hasDefaultServer { - lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) - if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok { - defaultServer.healthCheck = func() bool { return false } - lb.servers[lb.defaultServerAddress] = defaultServer + // no-op if state did not change + if new_state == stateInvalid { + return + } + + // handle active transition and sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // handle states of other servers when attempting to make this one active + if new_state == stateActive { + for _, s := range sl.servers { + if srv.address == s.address { + continue + } + switch s.state { + case stateFailed, stateStandby, stateRecovering, stateHealthy: + // close connections to other non-active servers whenever we have a new active server + defer s.closeAll() + case stateActive: + if len(s.connections) > len(srv.connections) { + // if there is a currently active server that has more connections that we do than we do, + // close our connections and go to preferred instead + new_state = statePreferred + defer srv.closeAll() + } else { + // otherwise, close its connections and demote it to preferred + s.state = statePreferred + defer s.closeAll() + } + } } } - lb.currentServerAddress = lb.randomServers[0] - lb.nextServerIndex = 1 - return true + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return + } + + logrus.Infof("Server %s->%s from successful %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) } -// nextServer attempts to get the next server in the loadbalancer server list. -// If another goroutine has already updated the current server address to point at -// a different address than just failed, nothing is changed. Otherwise, a new server address -// is stored to the currentServerAddress field, and returned for use. -// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex. -func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { - // note: these fields are not protected by the mutex, so we clamp the index value and update - // the index/current address using local variables, to avoid time-of-check vs time-of-use - // race conditions caused by goroutine A incrementing it in between the time goroutine B - // validates its value, and uses it as a list index. - currentServerAddress := lb.currentServerAddress - nextServerIndex := lb.nextServerIndex +// recordSuccess records a failed check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordFailure(srv *server, r reason) { + var new_state state + switch srv.state { + case stateRecovering: + if r == reasonHealthCheck { + // only demote from recovering if a dial fails, health checks may + // continue to fail despite it beig dialable. just leave it in + // recovering and don't close any connections. + new_state = stateFailed + } + case stateHealthy, statePreferred, stateActive: + // should not have any connections when in any state other than active or + // recovering, but close them all anyway to force failover. + defer srv.closeAll() + new_state = stateFailed + } - if len(lb.randomServers) == 0 { - return "", errors.New("No servers in load balancer proxy list") + // no-op if state did not change + if new_state == stateInvalid { + return } - if len(lb.randomServers) == 1 { - return currentServerAddress, nil + + // sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return } - if failedServer != currentServerAddress { - return currentServerAddress, nil + + logrus.Infof("Server %s->%s from failed %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) +} + +// state is possible server health states, in increasing order of preference. +// The server list is kept sorted in descending order by this state value. +type state int + +const ( + stateInvalid state = iota + stateFailed // failed a health check or dial + stateStandby // reserved for use by default server if not in server list + stateRecovering // successfully health checked once, or dialed when failed + stateHealthy // normal state + statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance + stateActive // currently active server +) + +func (s state) String() string { + switch s { + case stateInvalid: + return "INVALID" + case stateFailed: + return "FAILED" + case stateStandby: + return "STANDBY" + case stateRecovering: + return "RECOVERING" + case stateHealthy: + return "HEALTHY" + case statePreferred: + return "PREFERRED" + case stateActive: + return "ACTIVE" + default: + return "UNKNOWN" } - if nextServerIndex >= len(lb.randomServers) { - nextServerIndex = 0 +} + +// reason specifies the reason for a successful or failed health report +type reason int + +const ( + reasonDial reason = iota + reasonHealthCheck +) + +func (r reason) String() string { + switch r { + case reasonDial: + return "dial" + case reasonHealthCheck: + return "health check" + default: + return "unknown reason" } +} - currentServerAddress = lb.randomServers[nextServerIndex] - nextServerIndex++ +// server tracks the connections to a server, so that they can be closed when the server is removed. +type server struct { + // This mutex protects access to the connections map. All direct access to the map should be protected by it. + mutex sync.Mutex + address string + isDefault bool + state state + lastTransition time.Time + healthCheck HealthCheckFunc + connections map[net.Conn]struct{} +} - lb.currentServerAddress = currentServerAddress - lb.nextServerIndex = nextServerIndex +// newServer creates a new server, with a default health check +// and default/state fields appropriate for whether or not +// the server is a full server, or just a fallback default. +func newServer(address string, isDefault bool) *server { + state := statePreferred + if isDefault { + state = stateStandby + } + return &server{ + address: address, + isDefault: isDefault, + state: state, + lastTransition: time.Now(), + healthCheck: func() HealthCheckResult { return HealthCheckResultOK }, + connections: make(map[net.Conn]struct{}), + } +} - return currentServerAddress, nil +func (s *server) String() string { + format := "%s@%s" + if s.isDefault { + format += "*" + } + return fmt.Sprintf(format, s.address, s.state) } -// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map -func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := defaultDialer.Dial(network, address) +// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map +func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) { + if s.state == stateInvalid { + return nil, fmt.Errorf("server %s is stopping", s.address) + } + + conn, err := defaultDialer.Dial(network, s.address) if err != nil { return nil, err } @@ -132,7 +429,7 @@ func (s *server) closeAll() { defer s.mutex.Unlock() if l := len(s.connections); l > 0 { - logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) + logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s) for conn := range s.connections { // Close the connection in a goroutine so that we don't hold the lock while doing so. go conn.Close() @@ -140,6 +437,12 @@ func (s *server) closeAll() { } } +// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. +type serverConn struct { + server *server + net.Conn +} + // Close removes the connection entry from the server's connection map, and // closes the wrapped connection. func (sc *serverConn) Close() error { @@ -150,73 +453,43 @@ func (sc *serverConn) Close() error { return sc.Conn.Close() } -// SetDefault sets the selected address as the default / fallback address -func (lb *LoadBalancer) SetDefault(serverAddress string) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { - defer server.closeAll() - delete(lb.servers, lb.defaultServerAddress) - } - // if the new default server doesn't have an entry in the map, add one - but - // with a failing health check so that it is only used as a last resort. - if _, ok := lb.servers[serverAddress]; !ok { - lb.servers[serverAddress] = &server{ - address: serverAddress, - healthCheck: func() bool { return false }, - connections: make(map[net.Conn]struct{}), +// runHealthChecks periodically health-checks all servers. +func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) { + wait.Until(func() { + for _, s := range sl.getServers() { + switch s.healthCheck() { + case HealthCheckResultOK: + sl.recordSuccess(s, reasonHealthCheck) + case HealthCheckResultFailed: + sl.recordFailure(s, reasonHealthCheck) + } } - } - - lb.defaultServerAddress = serverAddress - logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress) + }, time.Second, ctx.Done()) + logrus.Debugf("Stopped health checking for load balancer %s", serviceName) } -// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. -func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - if server := lb.servers[address]; server != nil { - logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address) - server.healthCheck = healthCheck - } else { - logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address) +// dialContext attemps to dial a connection to a server from the server list. +// Success or failure is recorded to ensure that server state is updated appropriately. +func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { + for _, s := range sl.getServers() { + dialTime := time.Now() + conn, err := s.dialContext(ctx, network) + if err == nil { + sl.recordSuccess(s, reasonDial) + return conn, nil + } + logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err) + sl.recordFailure(s, reasonDial) } + return nil, errors.New("all servers failed") } -// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their -// connections closed, to force clients to switch over to a healthy server. -func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { - previousStatus := map[string]bool{} - wait.Until(func() { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - var healthyServerExists bool - for address, server := range lb.servers { - status := server.healthCheck() - healthyServerExists = healthyServerExists || status - if status == false && previousStatus[address] == true { - // Only close connections when the server transitions from healthy to unhealthy; - // we don't want to re-close all the connections every time as we might be ignoring - // health checks due to all servers being marked unhealthy. - defer server.closeAll() - } - previousStatus[address] = status - } - - // If there is at least one healthy server, and the default server is not in the server list, - // close all the connections to the default server so that clients reconnect and switch over - // to a preferred server. - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - if healthyServerExists && !hasDefaultServer { - if server, ok := lb.servers[lb.defaultServerAddress]; ok { - defer server.closeAll() - } - } - }, time.Second, ctx.Done()) - logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) +// compareServers is a comparison function that can be used to sort the server list +// so that servers with a more preferred state, or higher number of connections, are ordered first. +func compareServers(a, b *server) int { + c := cmp.Compare(b.state, a.state) + if c == 0 { + return cmp.Compare(len(b.connections), len(a.connections)) + } + return c } diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index e711623e467e..56d86a031366 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -22,7 +22,7 @@ type Proxy interface { SupervisorAddresses() []string APIServerURL() string IsAPIServerLBEnabled() bool - SetHealthCheck(address string, healthCheck func() bool) + SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) } // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If @@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor return nil, err } p.supervisorLB = lb - p.supervisorURL = lb.LoadBalancerServerURL() + p.supervisorURL = lb.LocalURL() p.apiServerURL = p.supervisorURL } @@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) { p.supervisorAddresses = supervisorAddresses } -func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) { +func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) { if p.supervisorLB != nil { p.supervisorLB.SetHealthCheck(address, healthCheck) } @@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error { return err } p.apiServerLB = lb - p.apiServerURL = lb.LoadBalancerServerURL() + p.apiServerURL = lb.LocalURL() } else { p.apiServerURL = u.String() } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index a5df415c7343..82d555cb10d1 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/websocket" agentconfig "github.com/k3s-io/k3s/pkg/agent/config" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/agent/proxy" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" @@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, wg, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, nil, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -427,8 +428,8 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo } type agentConnection struct { - cancel context.CancelFunc - connected func() bool + cancel context.CancelFunc + healthCheck loadbalancer.HealthCheckFunc } // connect initiates a connection to the remotedialer server. Incoming dial requests from @@ -484,8 +485,13 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup }() return agentConnection{ - cancel: cancel, - connected: func() bool { return connected }, + cancel: cancel, + healthCheck: func() loadbalancer.HealthCheckResult { + if connected { + return loadbalancer.HealthCheckResultOK + } + return loadbalancer.HealthCheckResultFailed + }, } } diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index 57a2e48c80c1..a842fa625155 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -52,7 +52,7 @@ func NewETCDProxy(ctx context.Context, supervisorPort int, dataDir, etcdURL stri return nil, err } e.etcdLB = lb - e.etcdLBURL = lb.LoadBalancerServerURL() + e.etcdLBURL = lb.LocalURL() e.fallbackETCDAddress = u.Host e.etcdPort = u.Port() @@ -112,10 +112,8 @@ func (e *etcdproxy) ETCDServerURL() string { // start a polling routine that makes periodic requests to the etcd node's supervisor port. // If the request fails, the node is marked unhealthy. -func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true +func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc { + var status loadbalancer.HealthCheckResult host, _, _ := net.SplitHostPort(address) url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort))) @@ -131,13 +129,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() } if err != nil || statusCode != http.StatusOK { logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode) - connected = false + status = loadbalancer.HealthCheckResultFailed } else { - connected = true + status = loadbalancer.HealthCheckResultOK } }, 5*time.Second, 1.0, true) - return func() bool { - return connected + return func() loadbalancer.HealthCheckResult { + // Reset the status to unknown on reading, until next time it is checked. + // This avoids having a health check result alter the server state between active checks. + s := status + status = loadbalancer.HealthCheckResultUnknown + return s } }