Skip to content

Commit

Permalink
remove some use of errBadConnNoWrite
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Nov 21, 2024
1 parent 41a5fa2 commit d366da3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 35 deletions.
5 changes: 5 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func newBuffer(nc net.Conn) buffer {
}
}

// busy retruns true if the buffer contains some read data.
func (b *buffer) busy() bool {
return b.length > 0
}

// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
Expand Down
10 changes: 7 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,14 @@ func (mc *mysqlConn) Close() (err error) {
if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit)
}
mc.close()
return
}

// close closes the network connection and cleare results without sending COM_QUIT.
func (mc *mysqlConn) close() {
mc.cleanup()
mc.clearResult()
return
}

// Closes the network connection and unsets internal variables. Do not call this
Expand Down Expand Up @@ -637,7 +641,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() {
if mc.closed.Load() || mc.buf.busy() {
return driver.ErrBadConn
}

Expand Down Expand Up @@ -671,7 +675,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load()
return !mc.closed.Load() && !mc.buf.busy()
}

var _ driver.SessionResetter = &mysqlConn{}
Expand Down
47 changes: 15 additions & 32 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}

Expand All @@ -45,7 +45,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {

// check packet sync [8 bit]
if data[3] != mc.sequence {
mc.Close()
mc.close()
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
}
Expand All @@ -59,7 +59,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// there was no previous packet
if prevData == nil {
mc.log(ErrMalformPkt)
mc.Close()
mc.close()
return nil, ErrInvalidConn
}

Expand All @@ -69,12 +69,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
return nil, err
}

// return data if this was the last packet
Expand Down Expand Up @@ -125,10 +124,10 @@ func (mc *mysqlConn) writePacket(data []byte) error {

n, err := mc.netConn.Write(data[:4+size])
if err != nil {
mc.cleanup()
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
mc.cleanup()
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
mc.log(err)
Expand Down Expand Up @@ -162,11 +161,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return
}

Expand Down Expand Up @@ -312,9 +306,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -404,9 +397,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data, err := mc.buf.takeBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
mc.cleanup()
return err
}

// Add the auth data [EOF]
Expand All @@ -424,9 +416,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {

data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand All @@ -443,9 +433,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand All @@ -464,9 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {

data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// Add command byte
Expand Down Expand Up @@ -1007,9 +993,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
return err
}

// command [1 byte]
Expand Down Expand Up @@ -1207,8 +1191,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.log(err)
return errBadConnNoWrite
return err
}
}

Expand Down

0 comments on commit d366da3

Please sign in to comment.