diff --git a/connection.go b/connection.go index ef6fc9e4..140adb34 100644 --- a/connection.go +++ b/connection.go @@ -86,15 +86,6 @@ func (mc *mysqlConn) handleParams() (err error) { return } -// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn. -// This function is used to return driver.ErrBadConn only when safe to retry. -func (mc *mysqlConn) markBadConn(err error) error { - if err == errBadConnNoWrite { - return driver.ErrBadConn - } - return err -} - func (mc *mysqlConn) Begin() (driver.Tx, error) { return mc.begin(false) } @@ -113,7 +104,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if err == nil { return &mysqlTx{mc}, err } - return nil, mc.markBadConn(err) + return nil, err } func (mc *mysqlConn) Close() (err error) { @@ -315,7 +306,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err copied := mc.result return &copied, err } - return nil, mc.markBadConn(err) + return nil, err } // Internal function to execute commands @@ -323,7 +314,7 @@ func (mc *mysqlConn) exec(query string) error { handleOk := mc.clearResult() // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { - return mc.markBadConn(err) + return err } // Read Result @@ -371,7 +362,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Send command err := mc.writeCommandPacketStr(comQuery, query) if err != nil { - return nil, mc.markBadConn(err) + return nil, err } // Read Result @@ -462,7 +453,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { handleOk := mc.clearResult() if err = mc.writeCommandPacket(comPing); err != nil { - return mc.markBadConn(err) + return err } return handleOk.readResultOK() diff --git a/connection_test.go b/connection_test.go index 6f8d2a6d..015190b7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -157,23 +157,6 @@ func TestCleanCancel(t *testing.T) { } } -func TestPingMarkBadConnection(t *testing.T) { - nc := badConnection{err: errors.New("boom")} - mc := &mysqlConn{ - netConn: nc, - buf: newBuffer(nc), - maxAllowedPacket: defaultMaxAllowedPacket, - closech: make(chan struct{}), - cfg: NewConfig(), - } - - err := mc.Ping(context.Background()) - - if err != driver.ErrBadConn { - t.Errorf("expected driver.ErrBadConn, got %#v", err) - } -} - func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} mc := &mysqlConn{ @@ -187,7 +170,7 @@ func TestPingErrInvalidConn(t *testing.T) { err := mc.Ping(context.Background()) if err != nc.err { - t.Errorf("expected %#v, got %#v", nc.err, err) + t.Errorf("expected %v, got %#v", nc.err, err) } } diff --git a/errors.go b/errors.go index 584617b1..dcd38905 100644 --- a/errors.go +++ b/errors.go @@ -29,12 +29,6 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`") ErrBusyBuffer = errors.New("busy buffer") - - // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. - // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn - // to trigger a resend. Use mc.markBadConn(err) to do this. - // See https://github.com/go-sql-driver/mysql/pull/302 - errBadConnNoWrite = errors.New("bad connection") ) var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime)) diff --git a/packets.go b/packets.go index eb4e0cef..0dab3837 100644 --- a/packets.go +++ b/packets.go @@ -125,17 +125,11 @@ func (mc *mysqlConn) writePacket(data []byte) error { n, err := mc.netConn.Write(data[:4+size]) if err != nil { + mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { return cerr } - mc.cleanup() - if n == 0 && pktLen == len(data)-4 { - // only for the first loop iteration when nothing was written yet - mc.log(err) - return errBadConnNoWrite - } else { - return err - } + return err } if n != 4+size { // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes. @@ -305,8 +299,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] @@ -395,8 +389,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { data, err := mc.buf.takeBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -415,8 +409,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -434,8 +428,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -455,8 +449,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add command byte @@ -998,8 +992,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // command [1 byte] @@ -1197,8 +1191,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } } diff --git a/statement.go b/statement.go index 35b02bbe..321846d5 100644 --- a/statement.go +++ b/statement.go @@ -56,7 +56,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, stmt.mc.markBadConn(err) + return nil, err } mc := stmt.mc @@ -99,7 +99,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, stmt.mc.markBadConn(err) + return nil, err } mc := stmt.mc