From a443dbff0ba07eb26311b31bb8220d0de0b0ea49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Putra?= Date: Fri, 23 Sep 2022 11:51:32 +0200 Subject: [PATCH] transport: WrapConn, close connection when initialization fails WrapConn didn't close connReader and connWriter loops, leaving this responsibility to its callers, potentially leading to goroutine leaks, this is mitigated by closing those loops and never returning a non-nil Conn pointer when WrapConn fails. Fixes #289 --- transport/cluster.go | 4 +- transport/conn.go | 6 +-- transport/conn_integration_test.go | 63 ++++++++++++++++++++++++++++++ transport/pool.go | 6 --- 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/transport/cluster.go b/transport/cluster.go index 15f94530..f7a724b5 100644 --- a/transport/cluster.go +++ b/transport/cluster.go @@ -167,14 +167,12 @@ func (c *Cluster) NewControl(ctx context.Context) (*Conn, error) { if err := conn.RegisterEventHandler(ctx, c.handleEvent, c.handledEvents...); err == nil { return conn, nil } else { + conn.Close() errs = append(errs, fmt.Sprintf("%s failed to register for events: %s", conn, err)) } } else { errs = append(errs, fmt.Sprintf("%s failed to connect: %s", addr, err)) } - if conn != nil { - conn.Close() - } } return nil, fmt.Errorf("couldn't open control connection to any known host:\n%s", strings.Join(errs, "\n")) diff --git a/transport/conn.go b/transport/conn.go index 290fefb1..4ed962df 100644 --- a/transport/conn.go +++ b/transport/conn.go @@ -385,9 +385,6 @@ func OpenShardConn(ctx context.Context, addr string, si ShardInfo, cfg ConnConfi conn, err := OpenLocalPortConn(ctx, addr, it(), cfg) if err != nil { cfg.Logger.Infof("%s dial error: %s (try %d/%d)", addr, err, i, maxTries) - if conn != nil { - conn.Close() - } continue } return conn, nil @@ -506,7 +503,8 @@ func WrapConn(ctx context.Context, conn net.Conn, cfg ConnConfig) (*Conn, error) go c.r.loop(ctx) if err := c.init(ctx); err != nil { - return c, err + c.Close() + return nil, err } return c, nil diff --git a/transport/conn_integration_test.go b/transport/conn_integration_test.go index a161ca61..875c740a 100644 --- a/transport/conn_integration_test.go +++ b/transport/conn_integration_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "math/rand" + "net" "os/signal" "strconv" "sync" @@ -278,3 +279,65 @@ func testCompression(ctx context.Context, t *testing.T, c frame.Compression, toS } } } + +func TestConnectedToNonCqlServer(t *testing.T) { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGABRT, syscall.SIGTERM) + defer cancel() + + t.Logf("%+v", testingConnConfig) + testCases := []struct { + name string + response []byte + }{ + { + name: "non-cql response", + response: []byte("0"), + }, + { + name: "non supported cql response", + response: func() []byte { + var buf frame.Buffer + frame := frame.Header{ + Version: frame.CQLv4, + OpCode: frame.OpReady, + } + + frame.WriteTo(&buf) + return buf.Bytes() + }(), + }, + } + + for i := 0; i < len(testCases); i++ { + tc := testCases[i] + t.Run(tc.name, func(t *testing.T) { + server, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + defer server.Close() + go func() { + conn, err := server.Accept() + if err != nil { + t.Log(err) + t.Fail() + return + } + go func(conn net.Conn) { + defer conn.Close() + conn.Write(tc.response) + }(conn) + }() + + addr := server.Addr().String() + conn, err := OpenConn(ctx, addr, nil, testingConnConfig) + if err == nil { + t.Fatal("connecting to non-cql server should fail") + } + t.Log(err) + if conn != nil { + t.Fatal("connecting to non-cql server should return a nil-conn") + } + }) + } +} diff --git a/transport/pool.go b/transport/pool.go index b15ec62a..adadb984 100644 --- a/transport/pool.go +++ b/transport/pool.go @@ -141,9 +141,6 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error { conn, err := OpenConn(ctx, host, nil, r.cfg) span.stop() if err != nil { - if conn != nil { - conn.Close() - } return err } @@ -245,9 +242,6 @@ func (r *PoolRefiller) fill(ctx context.Context) { if r.pool.connObs != nil { r.pool.connObs.OnConnect(ConnectEvent{ConnEvent: ConnEvent{Addr: r.addr, Shard: si.Shard}, span: span, Err: err}) } - if conn != nil { - conn.Close() - } continue } if r.pool.connObs != nil {