diff --git a/go/vt/topo/zk2topo/zk_conn.go b/go/vt/topo/zk2topo/zk_conn.go index a0eec8b4340..60ba00a2bf6 100644 --- a/go/vt/topo/zk2topo/zk_conn.go +++ b/go/vt/topo/zk2topo/zk_conn.go @@ -277,6 +277,8 @@ func (c *ZkConn) withRetry(ctx context.Context, action func(conn *zk.Conn) error c.conn = nil } c.mu.Unlock() + log.Infof("zk conn: got ErrConnectionClosed: closing") + conn.Close() } return } @@ -327,13 +329,9 @@ func (c *ZkConn) maybeAddAuth(ctx context.Context) { // clears out the connection record. func (c *ZkConn) handleSessionEvents(conn *zk.Conn, session <-chan zk.Event) { for event := range session { - closeRequired := false switch event.State { - case zk.StateExpired, zk.StateConnecting: - closeRequired = true - fallthrough - case zk.StateDisconnected: + case zk.StateDisconnected, zk.StateExpired, zk.StateConnecting: c.mu.Lock() if c.conn == conn { // The ZkConn still references this @@ -341,9 +339,8 @@ func (c *ZkConn) handleSessionEvents(conn *zk.Conn, session <-chan zk.Event) { c.conn = nil } c.mu.Unlock() - if closeRequired { - conn.Close() - } + log.Infof("zk conn: got %v: closing", event.State) + conn.Close() log.Infof("zk conn: session for addr %v ended: %v", c.addr, event) return } diff --git a/go/vt/topo/zk2topo/zk_conn_test.go b/go/vt/topo/zk2topo/zk_conn_test.go new file mode 100644 index 00000000000..de66f7352ce --- /dev/null +++ b/go/vt/topo/zk2topo/zk_conn_test.go @@ -0,0 +1,65 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package zk2topo + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/z-division/go-zookeeper/zk" + + "vitess.io/vitess/go/testfiles" + "vitess.io/vitess/go/vt/zkctl" +) + +func TestZkConnClosedOnDisconnect(t *testing.T) { + zkd, serverAddr := zkctl.StartLocalZk(testfiles.GoVtTopoZk2topoZkID, testfiles.GoVtTopoZk2topoPort) + defer zkd.Teardown() + + conn := Connect(serverAddr) + defer conn.Close() + + _, _, err := conn.Get(context.Background(), "/") + if err != nil { + t.Fatalf("Get() failed: %v", err) + } + + if !conn.conn.State().IsConnected() { + t.Fatalf("Connection not connected: %v", conn.conn.State()) + } + + oldConn := conn.conn + + // force a disconnect + zkd.Shutdown() + zkd.Start() + + // do another get to trigger a new connection + _, _, err = conn.Get(context.Background(), "/") + if err != nil { + t.Fatalf("Get() failed: %v", err) + } + + // Check that old connection is closed + _, _, err = oldConn.Get("/") + require.ErrorContains(t, err, "zookeeper is closing") + + if oldConn.State() != zk.StateDisconnected { + t.Fatalf("Connection is not in disconnected state: %v", oldConn.State()) + } +}