From 7d18f3e6d04bff06be235159e6286cde6efb704f Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 21 Nov 2024 16:35:06 +0900 Subject: [PATCH] remove some use of errBadConnNoWrite --- buffer.go | 5 +++++ connection.go | 10 +++++++--- packets.go | 50 ++++++++++++++++--------------------------------- packets_test.go | 13 ++++++------- 4 files changed, 34 insertions(+), 44 deletions(-) diff --git a/buffer.go b/buffer.go index 0774c5c8c..2dc898724 100644 --- a/buffer.go +++ b/buffer.go @@ -43,6 +43,11 @@ func newBuffer(nc net.Conn) buffer { } } +// busy retruns true if the buffer contains some read data. +func (b *buffer) busy() bool { + return b.length > 0 +} + // flip replaces the active buffer with the background buffer // this is a delayed flip that simply increases the buffer counter; // the actual flip will be performed the next time we call `buffer.fill` diff --git a/connection.go b/connection.go index ef6fc9e40..95f05c92f 100644 --- a/connection.go +++ b/connection.go @@ -121,10 +121,14 @@ func (mc *mysqlConn) Close() (err error) { if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } + mc.close() + return +} +// close closes the network connection and cleare results without sending COM_QUIT. +func (mc *mysqlConn) close() { mc.cleanup() mc.clearResult() - return } // Closes the network connection and unsets internal variables. Do not call this @@ -637,7 +641,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.Load() { + if mc.closed.Load() || mc.buf.busy() { return driver.ErrBadConn } @@ -671,7 +675,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() + return !mc.closed.Load() && !mc.buf.busy() } var _ driver.SessionResetter = &mysqlConn{} diff --git a/packets.go b/packets.go index a2e7ef95c..56e5c2e04 100644 --- a/packets.go +++ b/packets.go @@ -32,12 +32,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.log(err) - mc.Close() - return nil, ErrInvalidConn + return nil, err } // packet length [24 bit] @@ -45,7 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { - mc.Close() + mc.close() if data[3] > mc.sequence { return nil, ErrPktSyncMul } @@ -59,7 +58,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // there was no previous packet if prevData == nil { mc.log(ErrMalformPkt) - mc.Close() + mc.close() return nil, ErrInvalidConn } @@ -69,12 +68,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.log(err) - mc.Close() - return nil, ErrInvalidConn + return nil, err } // return data if this was the last packet @@ -125,10 +123,10 @@ func (mc *mysqlConn) writePacket(data []byte) error { n, err := mc.netConn.Write(data[:4+size]) if err != nil { + mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { return cerr } - mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet mc.log(err) @@ -162,11 +160,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { data, err = mc.readPacket() if err != nil { - // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since - // in connection initialization we don't risk retrying non-idempotent actions. - if err == ErrInvalidConn { - return nil, "", driver.ErrBadConn - } return } @@ -312,9 +305,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Calculate packet length and get buffer with that size data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] @@ -404,9 +396,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -424,9 +415,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -443,9 +432,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -464,9 +451,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -1007,9 +992,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In this case the len(data) == cap(data) which is used to optimise the flow below. } if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // command [1 byte] @@ -1207,8 +1190,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite + return err } } diff --git a/packets_test.go b/packets_test.go index fa4683eab..ec1c01b8d 100644 --- a/packets_test.go +++ b/packets_test.go @@ -17,7 +17,6 @@ import ( ) var ( - errConnClosed = errors.New("connection is closed") errConnTooManyReads = errors.New("too many reads") errConnTooManyWrites = errors.New("too many writes") ) @@ -39,7 +38,7 @@ type mockConn struct { func (m *mockConn) Read(b []byte) (n int, err error) { if m.closed { - return 0, errConnClosed + return 0, net.ErrClosed } m.reads++ @@ -54,7 +53,7 @@ func (m *mockConn) Read(b []byte) (n int, err error) { } func (m *mockConn) Write(b []byte) (n int, err error) { if m.closed { - return 0, errConnClosed + return 0, net.ErrClosed } m.writes++ @@ -290,8 +289,8 @@ func TestReadPacketFail(t *testing.T) { // fail to read header conn.closed = true _, err = mc.readPacket() - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %v", err) + if err != net.ErrClosed { + t.Errorf("expected ErrClosed, got %v", err) } // reset @@ -303,8 +302,8 @@ func TestReadPacketFail(t *testing.T) { // fail to read body conn.maxReads = 1 _, err = mc.readPacket() - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %v", err) + if err != errConnTooManyReads { + t.Errorf("expected errConnTooManyReads, got %#v", err) } }