Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-1.31] Fix issue with loadbalancer failover to default server #11324

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}

newServer, err := lb.nextServer(targetServer)
Expand Down
186 changes: 157 additions & 29 deletions pkg/agent/loadbalancer/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"testing"
"time"

"github.com/k3s-io/k3s/pkg/cli/cmds"
"github.com/sirupsen/logrus"
)

Expand All @@ -24,7 +23,7 @@ type testServer struct {
prefix string
}

func createServer(prefix string) (*testServer, error) {
func createServer(ctx context.Context, prefix string) (*testServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
Expand All @@ -34,6 +33,10 @@ func createServer(prefix string) (*testServer, error) {
listener: listener,
}
go s.serve()
go func() {
<-ctx.Done()
s.close()
}()
return s, nil
}

Expand All @@ -49,6 +52,7 @@ func (s *testServer) serve() {
}

func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix)
s.listener.Close()
for _, conn := range s.conns {
conn.Close()
Expand All @@ -65,6 +69,10 @@ func (s *testServer) echo(conn net.Conn) {
}
}

func (s *testServer) address() string {
return s.listener.Addr().String()
}

func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n')
Expand All @@ -74,25 +82,31 @@ func ping(conn net.Conn) (string, error) {
return strings.TrimSpace(result), nil
}

// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
func Test_UnitFailOver(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ogServe, err := createServer("og")
defaultServer, err := createServer(ctx, "default")
if err != nil {
t.Fatalf("createServer(og) failed: %v", err)
t.Fatalf("createServer(default) failed: %v", err)
}

lbServe, err := createServer("lb")
node1Server, err := createServer(ctx, "node1")
if err != nil {
t.Fatalf("createServer(lb) failed: %v", err)
t.Fatalf("createServer(node1) failed: %v", err)
}

cfg := cmds.Agent{
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()),
DataDir: tmpDir,
node2Server, err := createServer(ctx, "node2")
if err != nil {
t.Fatalf("createServer(node2) failed: %v", err)
}

lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false)
// start the loadbalancer with the default server as the only server
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
Expand All @@ -103,50 +117,123 @@ func Test_UnitFailOver(t *testing.T) {
}
localAddress := parsedURL.Host

lb.Update([]string{lbServe.listener.Addr().String()})
// add the node as a new server address.
lb.Update([]string{node1Server.address()})

// make sure connections go to the node
conn1, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
result1, err := ping(conn1)
if err != nil {
if result, err := ping(conn1); err != nil {
t.Fatalf("ping(conn1) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn1) result: %v", result)
}
if result1 != "lb:ping" {
t.Fatalf("Unexpected ping result: %v", result1)
}

lbServe.close()
t.Log("conn1 tested OK")

// set failing health check for node 1
lb.SetHealthCheck(node1Server.address(), func() bool { return false })

// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
time.Sleep(2 * time.Second)

_, err = ping(conn1)
if err == nil {
if _, err := ping(conn1); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn1")
}

t.Log("conn1 closed on failure OK")

// make sure connection still goes to the first node - it is failing health checks but so
// is the default endpoint, so it should be tried first with health checks disabled,
// before failing back to the default.
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
result2, err := ping(conn2)
if err != nil {
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
if result2 != "og:ping" {
t.Fatalf("Unexpected ping result: %v", result2)

t.Log("conn2 tested OK")

// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)

if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}

t.Log("conn2 tested OK again")

// shut down the first node server to force failover to the default
node1Server.close()

// make sure new connections go to the default, and existing connections are closed
conn3, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
if result, err := ping(conn3); err != nil {
t.Fatalf("ping(conn3) failed: %v", err)
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn3) result: %v", result)
}

t.Log("conn3 tested OK")

if _, err := ping(conn2); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn2")
}

t.Log("conn2 closed on failure OK")

// add the second node as a new server address.
lb.Update([]string{node2Server.address()})

// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
if result, err := ping(conn4); err != nil {
t.Fatalf("ping(conn4) failed: %v", err)
} else if result != "node2:ping" {
t.Fatalf("Unexpected ping(conn4) result: %v", result)
}

t.Log("conn4 tested OK")

// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)

if _, err := ping(conn3); err == nil {
t.Fatal("Unexpected successful ping on connection conn3")
}

t.Log("conn3 closed on failure OK")
}

// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
func Test_UnitFailFast(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cfg := cmds.Agent{
ServerURL: "http://127.0.0.1:0/",
DataDir: tmpDir,
}

lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false)
serverURL := "http://127.0.0.1:0/"
lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
Expand All @@ -172,3 +259,44 @@ func Test_UnitFailFast(t *testing.T) {
t.Fatal("Test timed out")
}
}

// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
// within the expected duration
func Test_UnitFailUnreachable(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test in short mode.")
}
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

serverAddr := "192.0.2.1:6443"
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}

// Set failing health check to reduce retries
lb.SetHealthCheck(serverAddr, func() bool { return false })

conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}

done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)

select {
case err := <-done:
if err == nil {
t.Fatal("Unexpected successful ping from unreachable address")
}
case <-timeout:
t.Fatal("Test timed out")
}
}
Loading
Loading