diff --git a/go/vt/topo/zk2topo/zk_conn_test.go b/go/vt/topo/zk2topo/zk_conn_test.go index 0912294fd52..de66f7352ce 100644 --- a/go/vt/topo/zk2topo/zk_conn_test.go +++ b/go/vt/topo/zk2topo/zk_conn_test.go @@ -20,6 +20,7 @@ import ( "context" "testing" + "github.com/stretchr/testify/require" "github.com/z-division/go-zookeeper/zk" "vitess.io/vitess/go/testfiles" @@ -31,6 +32,8 @@ func TestZkConnClosedOnDisconnect(t *testing.T) { defer zkd.Teardown() conn := Connect(serverAddr) + defer conn.Close() + _, _, err := conn.Get(context.Background(), "/") if err != nil { t.Fatalf("Get() failed: %v", err) @@ -42,7 +45,7 @@ func TestZkConnClosedOnDisconnect(t *testing.T) { oldConn := conn.conn - // simulate a disconnect + // force a disconnect zkd.Shutdown() zkd.Start() @@ -54,11 +57,9 @@ func TestZkConnClosedOnDisconnect(t *testing.T) { // Check that old connection is closed _, _, err = oldConn.Get("/") - if err == nil { - t.Fatalf("Get() should have failed: %v", err) - } + require.ErrorContains(t, err, "zookeeper is closing") if oldConn.State() != zk.StateDisconnected { - t.Fatalf("Connection not closed: %v", oldConn.State()) + t.Fatalf("Connection is not in disconnected state: %v", oldConn.State()) } }