Skip to content

Commit

Permalink
remove errBadConnNoWrite and markBadConn
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Nov 13, 2024
1 parent f62f523 commit 70d8617
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 48 deletions.
19 changes: 5 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,6 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn.
// This function is used to return driver.ErrBadConn only when safe to retry.
func (mc *mysqlConn) markBadConn(err error) error {
if err == errBadConnNoWrite {
return driver.ErrBadConn
}
return err
}

func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}
Expand All @@ -113,7 +104,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if err == nil {
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
return nil, err
}

func (mc *mysqlConn) Close() (err error) {
Expand Down Expand Up @@ -315,15 +306,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
return nil, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
return err
}

// Read Result
Expand Down Expand Up @@ -371,7 +362,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
return nil, mc.markBadConn(err)
return nil, err
}

// Read Result
Expand Down Expand Up @@ -462,7 +453,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
return err
}

return handleOk.readResultOK()
Expand Down
8 changes: 4 additions & 4 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ func TestPingMarkBadConnection(t *testing.T) {

err := mc.Ping(context.Background())

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
if !errors.Is(err, nc.err) {
t.Errorf("expected %v, got %#v", nc.err, err)
}
}

Expand All @@ -186,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) {

err := mc.Ping(context.Background())

if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
if !errors.Is(err, nc.err) {
t.Errorf("expected %v, got %#v", nc.err, err)
}
}

Expand Down
6 changes: 0 additions & 6 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ var (
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
ErrBusyBuffer = errors.New("busy buffer")

// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend. Use mc.markBadConn(err) to do this.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)

var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))
Expand Down
38 changes: 16 additions & 22 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,11 @@ func (mc *mysqlConn) writePacket(data []byte) error {

n, err := mc.netConn.Write(data[:4+size])
if err != nil {
mc.cleanup()
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
mc.cleanup()
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
mc.log(err)
return errBadConnNoWrite
} else {
return err
}
return err
}
if n != 4+size {
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
Expand Down Expand Up @@ -305,8 +299,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -395,8 +389,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
data, err := mc.buf.takeBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add the auth data [EOF]
Expand All @@ -415,8 +409,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand All @@ -434,8 +428,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand All @@ -455,8 +449,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add command byte
Expand Down Expand Up @@ -998,8 +992,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// command [1 byte]
Expand Down Expand Up @@ -1197,8 +1191,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}
}

Expand Down
4 changes: 2 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
return nil, err
}

mc := stmt.mc
Expand Down Expand Up @@ -99,7 +99,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
return nil, err
}

mc := stmt.mc
Expand Down

0 comments on commit 70d8617

Please sign in to comment.