diff --git a/connection.go b/connection.go index 2b19c927..e181b4e1 100644 --- a/connection.go +++ b/connection.go @@ -111,15 +111,6 @@ func (mc *mysqlConn) handleParams() (err error) { return } -// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn. -// This function is used to return driver.ErrBadConn only when safe to retry. -func (mc *mysqlConn) markBadConn(err error) error { - if err == errBadConnNoWrite { - return driver.ErrBadConn - } - return err -} - func (mc *mysqlConn) Begin() (driver.Tx, error) { return mc.begin(false) } @@ -138,7 +129,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if err == nil { return &mysqlTx{mc}, err } - return nil, mc.markBadConn(err) + return nil, err } func (mc *mysqlConn) Close() (err error) { @@ -340,7 +331,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err copied := mc.result return &copied, err } - return nil, mc.markBadConn(err) + return nil, err } // Internal function to execute commands @@ -348,7 +339,7 @@ func (mc *mysqlConn) exec(query string) error { handleOk := mc.clearResult() // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { - return mc.markBadConn(err) + return err } // Read Result @@ -378,10 +369,10 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { handleOk := mc.clearResult() - if mc.closed.Load() { return nil, driver.ErrBadConn } + if len(args) != 0 { if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip @@ -393,10 +384,11 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } query = prepared } + // Send command err := mc.writeCommandPacketStr(comQuery, query) if err != nil { - return nil, mc.markBadConn(err) + return nil, err } // Read Result @@ -487,7 +479,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { handleOk := mc.clearResult() if err = mc.writeCommandPacket(comPing); err != nil { - return mc.markBadConn(err) + return err } return handleOk.readResultOK() diff --git a/connection_test.go b/connection_test.go index c59cb617..c866cee3 100644 --- a/connection_test.go +++ b/connection_test.go @@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) { netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), } err := mc.Ping(context.Background()) - if err != driver.ErrBadConn { - t.Errorf("expected driver.ErrBadConn, got %#v", err) + if !errors.Is(err, nc.err) { + t.Errorf("expected %v, got %#v", nc.err, err) } } @@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) { err := mc.Ping(context.Background()) - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %#v", err) + if !errors.Is(err, nc.err) { + t.Errorf("expected %v, got %#v", nc.err, err) } } diff --git a/errors.go b/errors.go index 584617b1..dcd38905 100644 --- a/errors.go +++ b/errors.go @@ -29,12 +29,6 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`") ErrBusyBuffer = errors.New("busy buffer") - - // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. - // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn - // to trigger a resend. Use mc.markBadConn(err) to do this. - // See https://github.com/go-sql-driver/mysql/pull/302 - errBadConnNoWrite = errors.New("bad connection") ) var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime)) diff --git a/packets.go b/packets.go index b90b14c5..ec62bbe3 100644 --- a/packets.go +++ b/packets.go @@ -117,39 +117,33 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Write packet if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() mc.log(err) + mc.cleanup() return err } } n, err := mc.netConn.Write(data[:4+size]) - if err == nil && n == 4+size { - mc.sequence++ - if size != maxPacketSize { - return nil - } - pktLen -= size - data = data[size:] - continue - } - - // Handle error - if err == nil { // n != len(data) + if err != nil { mc.cleanup() - mc.log(ErrMalformPkt) - } else { if cerr := mc.canceled.Value(); cerr != nil { return cerr } - if n == 0 && pktLen == len(data)-4 { - // only for the first loop iteration when nothing was written yet - return errBadConnNoWrite - } + return err + } + if n != 4+size { + // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes. + // The io.ErrShortWrite error is used to indicate that this rule has not been followed. mc.cleanup() - mc.log(err) + return io.ErrShortWrite } - return ErrInvalidConn + + mc.sequence++ + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] } } @@ -305,8 +299,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] @@ -394,8 +388,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { data, err := mc.buf.takeSmallBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -414,8 +408,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -433,8 +427,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -454,8 +448,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -997,8 +991,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // command [1 byte] @@ -1196,8 +1190,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } } diff --git a/statement.go b/statement.go index 35b02bbe..321846d5 100644 --- a/statement.go +++ b/statement.go @@ -56,7 +56,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, stmt.mc.markBadConn(err) + return nil, err } mc := stmt.mc @@ -99,7 +99,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, stmt.mc.markBadConn(err) + return nil, err } mc := stmt.mc